Beispiel #1
0
 def resolve_to_device_spec(self, device):
     """Resolve an AutoDist DeviceSpec or its string to a TensorFlow DeviceSpec.

     Args:
         device: An AutoDist ``DeviceSpec``, its string form, or a list/set of
             either. Lists and sets are resolved element-wise and returned as
             the same container type.

     Returns:
         A ``device_spec.DeviceSpecV2`` (or a list/set of them). Its job and
         task come from the first task registered in ``self._address_to_tasks``
         for the device's host address. Its device type and index come from
         ``device`` itself.

     Raises:
         ValueError: If no task is registered for the device's host address.
     """
     if isinstance(device, (list, set)):
         return type(device)(self.resolve_to_device_spec(d) for d in device)
     if isinstance(device, str):
         device = DeviceSpec.from_string(device)
     tasks = self._address_to_tasks.get(device.host_address)
     if not tasks:
         # Fail with a descriptive error instead of subscripting None or an
         # empty list, which would raise an opaque TypeError/IndexError.
         raise ValueError(
             'No task registered for host address {!r}'.format(device.host_address))
     t = tasks[0]  # temporarily only use the first one
     return device_spec.DeviceSpecV2(
         job=t['job'],  # temporarily not fully resolving it before memory issue solved
         task=t['task'],
         device_type=device.device_type.name,
         device_index=device.device_index
     )
 def _create_proxy(self, graph_item, gradient, target):
     """Create a ProxyVariable mirroring ``target``'s master variable on this worker.

     Returns None (no proxy) in two cases:

     - ``gradient`` is not an ``ops.Tensor``; the intent is that sparse
       variables are not replicated.
     - ``self.worker_device`` occurs in ``self.target_device``, which is taken
       to mean the variable is already local.

     Otherwise the proxy is placed on this worker's job/replica/task. If the
     master variable's device type is GPU, the proxy uses the same device
     index; otherwise it is placed on CPU with index 0.
     """
     # Sparse variables are not replicated.
     if not isinstance(gradient, ops.Tensor):
         return None
     # The variable already lives on this worker.
     if self.worker_device in self.target_device:
         return None

     worker_spec = device_spec.DeviceSpecV2.from_string(self.worker_device)
     master_var = graph_item.trainable_var_op_to_var.get(target.op)
     master_spec = device_spec.DeviceSpecV2.from_string(master_var.device)

     master_type = master_spec.device_type
     if master_type and master_type.upper() == 'GPU':
         proxy_type, proxy_index = 'GPU', master_spec.device_index
     else:
         proxy_type, proxy_index = 'CPU', 0

     proxy_device = device_spec.DeviceSpecV2(
         job=worker_spec.job,
         replica=worker_spec.replica,
         task=worker_spec.task,
         device_type=proxy_type,
         device_index=proxy_index,
     )
     return ProxyVariable(master_var, graph_item, proxy_device)
Beispiel #3
0
 def test_merge_removed(self):
     """Calling ``merge_from`` on a DeviceSpecV2 must raise AttributeError."""
     # Construct and parse the specs outside the assertRaises block, so an
     # unexpected AttributeError from construction cannot make this test pass.
     d = device_spec.DeviceSpecV2()
     other = device_spec.DeviceSpecV2.from_string("/task:1/cpu:0")
     with self.assertRaises(AttributeError):
         d.merge_from(other)