def collect_gradients(self):
    if self.ngradev > 1:
        # If some parameters may be left unused during the forward pass on some GPUs, use
        # `p.data.new_zeros(p.data.size()) if p.grad is None else p.grad` in place of `p.grad` below.
        # In most cases the plain `p.grad` is preferable, since the resulting error warns you when a
        # parameter is accidentally missing from the forward computation.
        grads = comm.reduce_add_coalesced([[p.grad for p in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]], self.output_device)
        for mp, grad in zip(filter_para_grad(self.module.parameters()), grads):
            mp.grad = grad
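# A minimal sketch of the substitution described in the comment above: a hypothetical variant of
# collect_gradients that zero-fills gradients of parameters left unused on some replicas instead
# of failing. Illustrative only; it is not called anywhere in this module.
def collect_gradients_allow_unused(self):
    if self.ngradev > 1:
        grads = comm.reduce_add_coalesced([[p.data.new_zeros(p.data.size()) if p.grad is None else p.grad for p in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]], self.output_device)
        for mp, grad in zip(filter_para_grad(self.module.parameters()), grads):
            mp.grad = grad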
def reset_grad(self):
    for para in filter_para_grad(self.module.parameters()):
        para.grad = None
    if self.nets is not None and self.ngradev > 1:
        for net in self.nets[1:self.ngradev]:
            for para in filter_para_grad(net.parameters()):
                para.grad = None
    self.ngradev = 0
def collect_gradients_func(self, func):
    if self.ngradev > 1:
        grads = comm.reduce_add_coalesced([[p.grad for p in filter_para_grad(func(net).parameters())] for net in self.nets[:self.ngradev]], self.output_device)
        for mp, grad in zip(filter_para_grad(func(self.module).parameters()), grads):
            mp.grad = grad
def zero_replicas_grad(self, func=None):
    if self.nets is not None and self.ngradev > 1:
        if func is None:
            for net in self.nets[1:self.ngradev]:
                for para in filter_para_grad(net.parameters()):
                    para.grad = None
        else:
            for net in self.nets[1:self.ngradev]:
                for para in filter_para_grad(func(net).parameters()):
                    para.grad = None
def update_replicas_para(self, parallel=False):
    params = [para.data for para in filter_para_grad(self.module.parameters())]
    if len(params) > 0:
        param_copies = tuple([t for tensors in comm.broadcast_coalesced(params, self.device_ids) for t in tensors])
        #for module, param_copy in zip(self.nets, [param_copies[i:i + len(params)] for i in range(0, len(param_copies), len(params))]):
        for module, param_copy in zip(self.nets[1:], [param_copies[i:i + len(params)] for i in range(len(params), len(param_copies), len(params))]):
            for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                mp.data = para
    self.ngradev = 0
def update_replicas(self, parallel=False):
    params = [para.data for para in filter_para_grad(self.module.parameters())]
    if len(params) > 0:
        param_copies = tuple([t for tensors in comm.broadcast_coalesced(params, self.device_ids) for t in tensors])
        # PyTorch's broadcast currently binds the parameters of self.nets[0] to those of self.module,
        # so the first replica can safely be skipped; the commented-out line below updates every
        # replica instead, which does not rely on this binding but is less efficient.
        #for module, param_copy in zip(self.nets, [param_copies[i:i + len(params)] for i in range(0, len(param_copies), len(params))]):
        for module, param_copy in zip(self.nets[1:], [param_copies[i:i + len(params)] for i in range(len(params), len(param_copies), len(params))]):
            for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                mp.data, mp.grad = para, None
    self.ngradev = 0
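# A minimal usage sketch of how the methods above typically combine in one training step,
# assuming `model` is the data-parallel wrapper exposing them, its forward output can be reduced
# to a single scalar loss, and `optimizer` was built over the parameters of model.module; all
# names here are illustrative and this helper is not used by the module.
def _example_training_step(model, optimizer, batch_input):
    loss = model(batch_input).sum()
    # each replica that took part in the forward pass accumulates its own gradients
    loss.backward()
    # reduce replica gradients onto model.module, update the master parameters, then
    # re-broadcast them to the replicas and clear the replica gradients
    model.collect_gradients()
    optimizer.step()
    model.update_replicas()
    return loss.item()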
def allocate(self, parameters=None, init_tensors=None):
    # pack each parameter group into a single contiguous weight with a matching contiguous
    # gradient buffer; groups holding a single parameter are kept as-is
    self.pll = self.pll if parameters is None else [filter_para_grad(pl) for pl in parameters]
    cpl = []
    if init_tensors is None:
        for pl in self.pll:
            if len(pl) > 1:
                _numel = sum(para.numel() for para in pl)
                _weight = nn.Parameter(pl[0].new_empty(_numel))
                _weight.grad = pl[0].new_zeros(_numel)
                cpl.append(_weight)
            else:
                _weight = pl[0]
                if _weight.grad is None:
                    _weight.grad = _weight.new_zeros(_weight.size())
                cpl.append(_weight)
    else:
        for pl, init_tensor in zip(self.pll, init_tensors):
            if len(pl) > 1:
                _numel = sum(para.numel() for para in pl) if init_tensor is None else init_tensor.numel()
                _weight = nn.Parameter(pl[0].new_empty(_numel) if init_tensor is None else init_tensor)
                _weight.grad = pl[0].new_zeros(_numel) if (init_tensor is None) or (init_tensor.grad is None) else init_tensor.grad
                cpl.append(_weight)
            else:
                _weight = pl[0]
                if _weight.grad is None:
                    _weight.grad = _weight.new_zeros(_weight.size()) if (init_tensor is None) or (init_tensor.grad is None) else init_tensor.grad.view(_weight.size())
                cpl.append(_weight)
    self.weights = nn.ParameterList(cpl)
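# A hedged sketch of the flattening idea behind allocate: pack a group of parameters into one
# contiguous buffer and return views over it, so optimizer updates and broadcasts touch a single
# tensor per group. Binding such views back onto the model's parameters is handled elsewhere in
# this project; this standalone helper only illustrates the concept (nn refers to torch.nn as in
# the rest of this module).
def _example_flatten_parameters(param_list):
    _numel = sum(p.numel() for p in param_list)
    flat = nn.Parameter(param_list[0].new_empty(_numel))
    flat.grad = param_list[0].new_zeros(_numel)
    views, offset = [], 0
    for p in param_list:
        views.append(flat.data[offset:offset + p.numel()].view_as(p))
        offset += p.numel()
    return flat, views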
def zero_grad(self, set_to_none=True):
    # `set_to_none` is accepted for interface compatibility with torch optimizers and is not used
    # here: contiguous gradient buffers are zeroed in place, ordinary gradients are set to None
    if self.is_contiguous_parameters:
        with torch.no_grad():
            for para in get_all_contiguous_parameters_m(self.module):
                para.grad.zero_()
            if self.nets is not None and self.ngradev > 1:
                for net in self.nets[1:self.ngradev]:
                    for para in get_all_contiguous_parameters_m(net):
                        para.grad.zero_()
    else:
        for para in filter_para_grad(self.module.parameters()):
            para.grad = None
        if self.nets is not None and self.ngradev > 1:
            for net in self.nets[1:self.ngradev]:
                for para in filter_para_grad(net.parameters()):
                    para.grad = None
    self.ngradev = 0
def build_optimizer(self, optm_func, *optm_args, multi_gpu_optimizer=False, contiguous_parameters=False, **optm_kwargs):
    self.is_contiguous_parameters = contiguous_parameters
    paras = filter_para_grad(self.module.parameters())
    if (not multi_gpu_optimizer) or self.nets is None or (len(paras) < 2):
        # single-optimizer case: optionally pack the parameters into contiguous buffers first
        if contiguous_parameters:
            if self.nets is not None:
                for net in self.nets[1:]:
                    get_contiguous_parameters_m(net)
            _mp = get_contiguous_parameters_m(self.module)
        else:
            _mp = self.module.parameters()
        return optm_func(_mp, *optm_args, **optm_kwargs)
    else:
        # split the parameters into len(self.device_ids) shards and build one optimizer per shard
        self.optm_splt, _np = divide_para_ind(paras, len(self.device_ids), return_np=True)
        if contiguous_parameters:
            for net in self.nets:
                get_contiguous_parameters_p([list(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)) for lind, rind in self.optm_splt], model=net)
            optml = [optm_func(get_contiguous_parameters_m(net, index=i), *optm_args, **optm_kwargs) for i, net in enumerate(self.nets)]
        else:
            optml = [optm_func(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter), *optm_args, **optm_kwargs) for net, (lind, rind,) in zip(self.nets, self.optm_splt)]
        # put the optimizers holding slightly more parameters first, so they start their optimization steps earlier
        optml, _device_ids = reorder_by_sort(_np, optml, self.device_ids[:len(optml)], reverse=True)
        return MultiGPUOptimizer(optml, device_ids=_device_ids)
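# Hedged usage sketch for build_optimizer, assuming `model` is the data-parallel wrapper that
# defines it; torch.optim.Adam stands in for any optimizer constructor and the hyper-parameters
# are illustrative. With both flags left False this reduces to a plain optimizer over
# model.module.parameters(). Not used by this module.
def _example_build_adam(model):
    return model.build_optimizer(torch.optim.Adam, lr=1e-4, betas=(0.9, 0.98), multi_gpu_optimizer=True, contiguous_parameters=True)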
def zero_replicas_grad(self, func=None):
    if self.nets is not None and self.ngradev > 1:
        if func is None:
            if self.is_contiguous_parameters:
                for net in self.nets[1:self.ngradev]:
                    for para in get_all_contiguous_parameters_m(net):
                        para.grad.zero_()
            else:
                for net in self.nets[1:self.ngradev]:
                    for para in filter_para_grad(net.parameters()):
                        para.grad = None
        else:
            if self.is_contiguous_parameters:
                for net in self.nets[1:self.ngradev]:
                    for para in filter_para_grad(func(net).parameters()):
                        para.grad.zero_()
            else:
                for net in self.nets[1:self.ngradev]:
                    for para in filter_para_grad(func(net).parameters()):
                        para.grad = None
def zero_replicas_grad(self):
    if self.nets is not None and self.ngradev > 1:
        for net in self.nets[1:self.ngradev]:
            for para in filter_para_grad(net.parameters()):
                para.grad = None
def update_replicas_para(self):
    if self.optm_splt is None:
        if self.is_contiguous_parameters:
            params = [para.data for para in get_all_contiguous_parameters_m(self.module)]
            param_copies = comm.broadcast_coalesced(params, self.device_ids)
            with torch.no_grad():
                for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                    for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy):
                        mp.data.copy_(para)
        else:
            params = [para.data for para in filter_para_grad(self.module.parameters())]
            param_copies = comm.broadcast_coalesced(params, self.device_ids)
            for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                    mp.data = para
    else:
        # each parameter shard i lives on device i; broadcast_coalesced expects the source device
        # first, so rotate device i to the front and rotate the returned copies back into device order
        for i, (net, (lind, rind,),) in enumerate(zip(self.nets, self.optm_splt)):
            _dev_params = [para.data for para in get_contiguous_parameters_m(net, index=i)] if self.is_contiguous_parameters else [para.data for para in range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)]
            if i > 0:
                _devices = self.device_ids[:]
                _devices.insert(0, _devices.pop(i))
            else:
                _devices = self.device_ids
            _dev_param_copies = comm.broadcast_coalesced(_dev_params, _devices)
            if i > 0:
                _dev_param_copies.insert(i, _dev_param_copies.pop(0))
                for pc, _dpc in zip(param_copies, _dev_param_copies):
                    pc.extend(_dpc)
            else:
                param_copies = _dev_param_copies
        if self.is_contiguous_parameters:
            with torch.no_grad():
                for module, param_copy in zip(self.nets, param_copies):
                    for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy):
                        mp.data.copy_(para)
        else:
            for module, param_copy in zip(self.nets, param_copies):
                for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                    mp.data = para
    self.ngradev = 0
def update_replicas(self):
    if self.optm_splt is None:
        if self.is_contiguous_parameters:
            params = [para.data for para in get_all_contiguous_parameters_m(self.module)]
            param_copies = comm.broadcast_coalesced(params, self.device_ids)
            with torch.no_grad():
                for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                    for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy):
                        mp.data.copy_(para)
                        mp.grad.zero_()
        else:
            params = [para.data for para in filter_para_grad(self.module.parameters())]
            param_copies = comm.broadcast_coalesced(params, self.device_ids)
            # PyTorch's broadcast currently binds the parameters of self.nets[0] to those of self.module,
            # so the first replica can safely be skipped; the commented-out line below updates every
            # replica instead, which does not rely on this binding but is less efficient.
            #for module, param_copy in zip(self.nets, param_copies):
            for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                    mp.data, mp.grad = para, None
    else:
        # each parameter shard i lives on device i; broadcast_coalesced expects the source device
        # first, so rotate device i to the front and rotate the returned copies back into device order
        for i, (net, (lind, rind,),) in enumerate(zip(self.nets, self.optm_splt)):
            _dev_params = [para.data for para in get_contiguous_parameters_m(net, index=i)] if self.is_contiguous_parameters else [para.data for para in range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)]
            if i > 0:
                _devices = self.device_ids[:]
                _devices.insert(0, _devices.pop(i))
            else:
                _devices = self.device_ids
            _dev_param_copies = comm.broadcast_coalesced(_dev_params, _devices)
            if i > 0:
                _dev_param_copies.insert(i, _dev_param_copies.pop(0))
                for pc, _dpc in zip(param_copies, _dev_param_copies):
                    pc.extend(_dpc)
            else:
                param_copies = _dev_param_copies
        if self.is_contiguous_parameters:
            with torch.no_grad():
                for module, param_copy in zip(self.nets, param_copies):
                    for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy):
                        mp.data.copy_(para)
                        mp.grad.zero_()
        else:
            for module, param_copy in zip(self.nets, param_copies):
                for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                    mp.data, mp.grad = para, None
    self.ngradev = 0
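# A hedged, standalone illustration of the device rotation used in the two methods above:
# comm.broadcast_coalesced expects the source tensors to live on the first device in the device
# list, so when the source is shard/replica i > 0 its device is moved to the front and the
# returned copies are rotated back into device order. Not used by this module; `comm` refers to
# torch.cuda.comm as in the rest of the file.
def _example_broadcast_from(src_tensors, device_ids, src_index):
    _devices = device_ids[:]
    if src_index > 0:
        _devices.insert(0, _devices.pop(src_index))
    copies = list(comm.broadcast_coalesced(src_tensors, _devices))
    if src_index > 0:
        copies.insert(src_index, copies.pop(0))
    return copies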
def collect_gradients(self):
    if self.optm_splt is not None:
        # sharded case: reduce the gradients of shard i from every replica that took part in the
        # forward pass onto device i, where the i-th optimizer holds the corresponding parameters
        if self.is_contiguous_parameters:
            for i, (net, device,) in enumerate(zip(self.nets, self.device_ids)):
                _dev_grads = [[para.grad for para in get_contiguous_parameters_m(_net, index=i)] for _net in self.nets[:self.ngradev]]
                if i > 0:
                    _dev_grads.insert(0, _dev_grads.pop(i) if i < self.ngradev else [para.grad for para in get_contiguous_parameters_m(net, index=i)])
                _dev_grads = comm.reduce_add_coalesced(_dev_grads, device)
                for mp, grad in zip(get_contiguous_parameters_m(net, index=i), _dev_grads):
                    mp.grad.copy_(grad)
        else:
            grads = [[para.grad for para in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]]
            for i, (net, device, (lind, rind,),) in enumerate(zip(self.nets, self.device_ids, self.optm_splt)):
                _dev_grads = [gradu[lind:rind] for gradu in grads]
                if i > 0:
                    _dev_grads.insert(0, _dev_grads.pop(i) if i < self.ngradev else [_pg.new_zeros(_pg.size(), device=device) for _pg in _dev_grads[0]])
                _dev_grads = comm.reduce_add_coalesced(_dev_grads, device)
                for mp, grad in zip(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter), _dev_grads):
                    mp.grad = grad
    elif self.ngradev > 1:
        if self.is_contiguous_parameters:
            grads = comm.reduce_add_coalesced([[para.grad for para in get_all_contiguous_parameters_m(net)] for net in self.nets[:self.ngradev]], self.output_device)
            for mp, grad in zip(get_all_contiguous_parameters_m(self.module), grads):
                mp.grad.copy_(grad)
        else:
            # as in the plain collect_gradients above: if some parameters may be left unused on some
            # GPUs, use `para.data.new_zeros(para.data.size()) if para.grad is None else para.grad`
            # in place of `para.grad`; keeping the plain `para.grad` warns you when a parameter is
            # accidentally missing from the forward computation.
            grads = comm.reduce_add_coalesced([[para.grad for para in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]], self.output_device)
            for mp, grad in zip(filter_para_grad(self.module.parameters()), grads):
                mp.grad = grad