def __init__(self, module, device_ids=None, output_device=None, dim=0):
    super(DataParallel, self).__init__()
    torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")

    # Check whether a usable GPU is available
    device_type = _get_available_device_type()
    if device_type is None:
        self.module = module
        self.device_ids = []
        return

    # By default, use all visible GPUs
    if device_ids is None:
        device_ids = _get_all_device_indices()

    # By default, the output device is the first entry in device_ids
    if output_device is None:
        output_device = device_ids[0]

    self.dim = dim
    self.module = module
    self.device_ids = [_get_device_index(x, True) for x in device_ids]
    self.output_device = _get_device_index(output_device, True)
    self.src_device_obj = torch.device(device_type, self.device_ids[0])

    # Check whether the devices are balanced; a warning is issued if the smallest
    # device's memory or multiprocessor count is below 0.75 of the largest
    _check_balance(self.device_ids)

    # Single GPU: just move the module onto it
    if len(self.device_ids) == 1:
        self.module.to(self.src_device_obj)
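For orientation, here is a minimal usage sketch of the constructor above. The toy nn.Linear model, the batch size, and the explicit device handling are illustrative assumptions, not part of the source being analyzed; with device_ids=None the wrapper falls back to all visible GPUs and takes the first one as the output device, exactly as the __init__ shows.

import torch
import torch.nn as nn

# Toy model purely for illustration; any nn.Module behaves the same way.
model = nn.Linear(10, 5)

if torch.cuda.is_available():
    # device_ids=None -> all visible GPUs; output_device=None -> device_ids[0]
    model = nn.DataParallel(model, device_ids=None, output_device=None, dim=0)
    model = model.cuda()  # parameters and buffers must live on device_ids[0]

inp = torch.randn(16, 10)
if torch.cuda.is_available():
    inp = inp.cuda()

out = model(inp)   # input scattered along dim 0, result gathered back
print(out.shape)   # torch.Size([16, 5])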
def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (list of int or torch.device): GPU location of the output.
            Use -1 to indicate the CPU. (default: device_ids[0])
    Returns:
        a Tensor containing the result of module(input) located on output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,) if inputs is not None else ()

    device_type = _get_available_device_type()

    if device_ids is None:
        device_ids = _get_all_device_indices()

    if output_device is None:
        output_device = device_ids[0]

    device_ids = [_get_device_index(x, True) for x in device_ids]
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device(device_type, device_ids[0])

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(src_device_obj, t.device))

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    # For a module without any inputs, an empty tuple and dict are created
    # so the module can still be executed on one device (the first in device_ids)
    if not inputs and not module_kwargs:
        inputs = ((),)
        module_kwargs = ({},)

    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
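As a complementary sketch, the functional form can be called directly. The example below assumes two visible CUDA devices (cuda:0 and cuda:1) and again uses a made-up nn.Linear module; its parameters must already sit on device_ids[0], otherwise the RuntimeError in the listing above is raised.

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

# Illustrative module; its parameters live on cuda:0, i.e. device_ids[0].
module = nn.Linear(10, 5).cuda(0)
inputs = torch.randn(16, 10).cuda(0)

# Scatter `inputs` along dim 0 over both GPUs, replicate `module`,
# run the replicas with parallel_apply, then gather onto output_device 0.
out = data_parallel(module, inputs, device_ids=[0, 1], output_device=0)
print(out.shape)  # torch.Size([16, 5]), located on cuda:0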