Exemplo n.º 1
0
  def device_function(self, op):
    """Choose a device for `op`.

    Args:
      op: an `Operation` or a `graph_pb2.NodeDef`. Its `device` field is read
        and, when device merging is enabled, folded into the chosen device.

    Returns:
      The device string to use for `op`. Ops whose type is in `self._ps_ops`
      are sent to a parameter-server task (when ps tasks and a ps device are
      configured); other ops get the worker device, or "" when there is none
      and merging is disabled.
    """
    merge = self._merge_devices
    # Without merging, a device already set on the op is returned untouched.
    if not merge and op.device:
      return op.device
    requested = pydev.from_string(op.device or "")

    if self._ps_tasks and self._ps_device:
      node_def = op if isinstance(op, graph_pb2.NodeDef) else op.node_def
      if node_def.op in self._ps_ops:
        # self._next_task() is called once for each ps op placed here.
        ps_string = "%s/task:%d" % (self._ps_device, self._next_task())
        if not merge:
          return ps_string
        ps_spec = pydev.from_string(ps_string)
        ps_spec.merge_from(requested)
        return ps_spec.to_string()

    if not merge:
      # op.device is empty on this path, so there is nothing to merge.
      return self._worker_device if self._worker_device else ""

    if self._worker_device:
      target = pydev.from_string(self._worker_device)
    else:
      target = pydev.Device()
    target.merge_from(requested)
    return target.to_string()
Exemplo n.º 2
0
    def device_function(self, op):
        """Choose a device for `op`.

        Args:
          op: an `Operation` or a `graph_pb2.NodeDef`. Its `device` field is
            read and, when device merging is enabled, folded into the result.

        Returns:
          The device string to use for `op`. Ops whose type is in
          `self._ps_ops` go to a parameter-server task (when ps tasks and a
          ps device are configured); all other ops get the worker device, or
          "" when there is none and merging is disabled.
        """
        merging = self._merge_devices
        # Without merging, an explicit device on the op wins outright.
        if not merging and op.device:
            return op.device
        op_spec = pydev.from_string(op.device or "")

        if self._ps_tasks and self._ps_device:
            if isinstance(op, graph_pb2.NodeDef):
                node_def = op
            else:
                node_def = op.node_def
            if node_def.op in self._ps_ops:
                # self._next_task() is called once per ps op placed here.
                ps_name = "%s/task:%d" % (self._ps_device, self._next_task())
                if not merging:
                    return ps_name
                merged = pydev.from_string(ps_name)
                merged.merge_from(op_spec)
                return merged.to_string()

        if not merging:
            # op.device is empty on this path, so nothing needs merging.
            return self._worker_device if self._worker_device else ""

        base = (pydev.from_string(self._worker_device)
                if self._worker_device else pydev.Device())
        base.merge_from(op_spec)
        return base.to_string()
Exemplo n.º 3
0
  def testFromString(self):
    """from_string round-trips simple specs and parses device indices."""
    d = device.from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())
    # Two device types in one string must be rejected. Python 3 exceptions
    # have no `.message` attribute, so inspect str() of the exception.
    with self.assertRaises(Exception) as e:
      device.from_string("/job:muu/gpu:2/cpu:0")
    self.assertIn("Cannot specify multiple device", str(e.exception))

    # A wildcard index parses to None; explicit indices parse to ints in both
    # the short ("gpu:7") and long ("device:GPU:7") forms.
    d = device.from_string("/job:foo/replica:0/task:3/cpu:*")
    self.assertIsNone(d.device_index)
    d = device.from_string("/job:foo/replica:0/task:3/gpu:7")
    self.assertEqual(7, d.device_index)
    d = device.from_string("/job:foo/replica:0/task:3/device:GPU:7")
    self.assertEqual(7, d.device_index)
Exemplo n.º 4
0
    def testFromString(self):
        """from_string round-trips simple specs and parses device indices."""
        d = device.from_string("/job:foo/replica:0")
        self.assertEqual("/job:foo/replica:0", d.to_string())
        # Two device types in one string must be rejected.
        with self.assertRaises(Exception) as e:
            device.from_string("/job:muu/gpu:2/cpu:0")
        self.assertIn("Cannot specify multiple device", str(e.exception))

        # A wildcard index parses to None; explicit indices parse to ints in
        # both the short ("gpu:7") and long ("device:GPU:7") forms.
        d = device.from_string("/job:foo/replica:0/task:3/cpu:*")
        self.assertIsNone(d.device_index)
        d = device.from_string("/job:foo/replica:0/task:3/gpu:7")
        self.assertEqual(7, d.device_index)
        d = device.from_string("/job:foo/replica:0/task:3/device:GPU:7")
        self.assertEqual(7, d.device_index)
Exemplo n.º 5
0
def pin_variables_on_cpu(op):
    """Return a CPU device for Variable nodes if no device type is specified.

    Args:
      op: The ops.Operation (or graph_pb2.NodeDef) describing the node for
        which a device should be chosen. The op.device field is respected.

    Returns:
      The op's existing device string when it already names a device type or
      the node is not a variable op; otherwise that string rewritten by
      set_cpu0 to use CPU:0.
    """
    device = "" if op.device is None else op.device

    # An explicitly requested device type is never overridden.
    if pydev.from_string(device).device_type:
        return device

    if not isinstance(op, ops.Operation):
        assert isinstance(op, graph_pb2.NodeDef)
        node_def = op
    else:
        node_def = op.node_def

    return set_cpu0(device) if _is_variable_op(node_def.op) else device
Exemplo n.º 6
0
def pin_variables_on_cpu(op):
    """Pin variable ops to CPU:0 unless a device type was already chosen.

    Args:
      op: The ops.Operation (or graph_pb2.NodeDef) describing the node for
        which a device should be chosen. The op.device field is respected.

    Returns:
      A device string: set_cpu0(op.device) for variable ops without a device
      type, and op.device ("" if None) unchanged in every other case.
    """
    requested = "" if op.device is None else op.device
    if pydev.from_string(requested).device_type:
        # A device type was set explicitly; leave it alone.
        return requested

    if isinstance(op, ops.Operation):
        op_type = op.node_def.op
    else:
        assert isinstance(op, graph_pb2.NodeDef)
        op_type = op.op

    if not _is_variable_op(op_type):
        return requested
    return set_cpu0(requested)
Exemplo n.º 7
0
def set_cpu0(device_string):
  """Return `device_string` with its device replaced by CPU:0.

  The string is parsed with pydev.from_string, only the device type and index
  are overwritten, and the result is re-serialized with to_string().

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  spec = pydev.from_string(device_string)
  spec.device_type, spec.device_index = "CPU", 0
  return spec.to_string()
Exemplo n.º 8
0
def pin_to_cpu(op):
  """Returns a CPU device for the given node.

  An op with no device type gets CPU:0 via set_cpu0. An op already on a CPU
  keeps its device string. An op on a non-CPU device also keeps its device
  string, and an info message records that it was not pinned.
  """
  device = "" if op.device is None else op.device
  device_type = pydev.from_string(device).device_type

  if not device_type:
    return set_cpu0(device)
  if device_type != "CPU":
    logging.info("Operation %s has been assigned to a non-CPU (%s), so "
                 "it will not be pinned to the CPU.", op.name, device_type)
  return device
Exemplo n.º 9
0
def pin_to_cpu(op):
  """Returns a CPU device for the given node.

  Unplaced ops (no device type) are rewritten to CPU:0 by set_cpu0; any op
  that already has a device type keeps its device string, with an info log
  when that type is not "CPU".
  """
  requested = op.device
  if requested is None:
    requested = ""
  parsed = pydev.from_string(requested)

  if parsed.device_type:
    if parsed.device_type != "CPU":
      logging.info("Operation %s has been assigned to a non-CPU (%s), so "
                   "it will not be pinned to the CPU.", op.name,
                   parsed.device_type)
    return requested
  return set_cpu0(requested)
Exemplo n.º 10
0
def set_cpu0(device_string):
    """Return `device_string` with its device replaced by CPU:0.

    The string is parsed with pydev.from_string; only the device type and
    index are overwritten before re-serializing with to_string().

    Args:
      device_string: A device string.

    Returns:
      A device string.
    """
    spec = pydev.from_string(device_string)
    spec.device_type = "CPU"
    spec.device_index = 0
    result = spec.to_string()
    return result
Exemplo n.º 11
0
  def testMerge(self):
    """merge_from overlays the fields set in its argument onto the device."""
    d = device.from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())
    d.merge_from(device.from_string("/task:1/gpu:2"))
    self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())

    # Successive merges replace job and device fields while task:1 persists.
    d = device.Device()
    d.merge_from(device.from_string("/task:1/cpu:0"))
    self.assertEqual("/task:1/device:CPU:0", d.to_string())
    d.merge_from(device.from_string("/job:boo/gpu:0"))
    self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())
    d.merge_from(device.from_string("/job:muu/cpu:2"))
    self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
    d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2"))
    self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
Exemplo n.º 12
0
    def testMerge(self):
        """merge_from overlays the fields set in its argument onto the device."""
        d = device.from_string("/job:foo/replica:0")
        self.assertEqual("/job:foo/replica:0", d.to_string())
        d.merge_from(device.from_string("/task:1/gpu:2"))
        self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2",
                         d.to_string())

        # Successive merges replace job and device fields; task:1 persists.
        d = device.Device()
        d.merge_from(device.from_string("/task:1/cpu:0"))
        self.assertEqual("/task:1/device:CPU:0", d.to_string())
        d.merge_from(device.from_string("/job:boo/gpu:0"))
        self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())
        d.merge_from(device.from_string("/job:muu/cpu:2"))
        self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
        d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2"))
        self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2",
                         d.to_string())