def resolve_to_device_spec(self, device):
    """Resolve an AutoDist DeviceSpec (or its string form) to a TensorFlow DeviceSpec.

    Args:
        device: An AutoDist DeviceSpec, its string representation, or a
            list/set of either. Collections are resolved element-wise and
            returned as the same collection type.

    Returns:
        A `device_spec.DeviceSpecV2` (or a list/set of them) whose job and
        task come from this object's address-to-task mapping.

    Raises:
        KeyError: If the device's host address is not registered in
            ``self._address_to_tasks``.
    """
    if isinstance(device, (list, set)):
        return type(device)(self.resolve_to_device_spec(d) for d in device)
    if isinstance(device, str):
        device = DeviceSpec.from_string(device)
    # Bug fix: `dict.get(...)` silently returned None for an unknown host
    # address, which then blew up with an opaque
    # "'NoneType' object is not subscriptable" TypeError. Index the dict
    # directly so an unregistered address fails with a clear KeyError.
    t = self._address_to_tasks[device.host_address][0]  # temporarily only use the first one
    return device_spec.DeviceSpecV2(
        job=t['job'],  # temporarily not fully resolving it before memory issue solved
        task=t['task'],
        device_type=device.device_type.name,
        device_index=device.device_index
    )
def _create_proxy(self, graph_item, gradient, target):
    """Build a local ProxyVariable for a dense variable hosted elsewhere.

    Returns None when no proxy is needed: either the gradient is sparse
    (not an `ops.Tensor`), or the variable already lives on this worker.
    """
    # Sparse variables are not replicated; local variables need no proxy.
    gradient_is_dense = isinstance(gradient, ops.Tensor)
    if not gradient_is_dense or self.worker_device in self.target_device:
        return None

    worker_spec = device_spec.DeviceSpecV2.from_string(self.worker_device)
    master_var = graph_item.trainable_var_op_to_var.get(target.op)
    var_spec = device_spec.DeviceSpecV2.from_string(master_var.device)

    # Place the proxy on the same kind of device as the master variable:
    # mirror its GPU index when it is on a GPU, otherwise fall back to CPU:0.
    if var_spec.device_type and var_spec.device_type.upper() == 'GPU':
        dev_type, dev_index = 'GPU', var_spec.device_index
    else:
        dev_type, dev_index = 'CPU', 0

    proxy_spec = device_spec.DeviceSpecV2(
        job=worker_spec.job,
        replica=worker_spec.replica,
        task=worker_spec.task,
        device_type=dev_type,
        device_index=dev_index)
    return ProxyVariable(master_var, graph_item, proxy_spec)
def test_merge_removed(self):
    """merge_from was removed from DeviceSpecV2; calling it must raise AttributeError."""
    with self.assertRaises(AttributeError):
        spec = device_spec.DeviceSpecV2()
        spec.merge_from(device_spec.DeviceSpecV2.from_string("/task:1/cpu:0"))