Example #1
	def collect_gradients(self):

		if self.ngradev > 1:
			# If some parameters may be unused during forward propagation on some GPUs, use
			# `p.data.new_zeros(p.data.size()) if p.grad is None else p.grad` instead of `p.grad`;
			# in most cases, however, keeping `p.grad` is preferable, since the resulting failure
			# warns you when a parameter is missed in the forward computation.
			grads = comm.reduce_add_coalesced([[p.grad for p in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]], self.output_device) # if self.ngradev > 1 else [p.grad for p in filter_para_grad(self.nets[0].parameters())]
			for mp, grad in zip(filter_para_grad(self.module.parameters()), grads):
				mp.grad = grad
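
The substitution described in the comment above would look roughly like the following. This is a sketch rather than code from the repository; the method name `collect_gradients_tolerant` is made up for illustration, and it assumes the same `comm`, `filter_para_grad`, `self.nets`, `self.ngradev` and `self.output_device` as Example #1.

	def collect_gradients_tolerant(self):

		if self.ngradev > 1:
			# sketch: substitute a zero tensor for the gradient of any parameter that was
			# not used on a replica, instead of assuming p.grad is always populated
			grads = comm.reduce_add_coalesced([[p.data.new_zeros(p.data.size()) if p.grad is None else p.grad for p in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]], self.output_device)
			for mp, grad in zip(filter_para_grad(self.module.parameters()), grads):
				mp.grad = grad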
Example #2
    def reset_grad(self):

        for para in filter_para_grad(self.module.parameters()):
            para.grad = None
        if self.nets is not None and self.ngradev > 1:
            for net in self.nets[1:self.ngradev]:
                for para in filter_para_grad(net.parameters()):
                    para.grad = None
        self.ngradev = 0
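
All of these snippets rely on the helper `filter_para_grad`, whose definition is not shown here. A minimal sketch of what it presumably does, keeping only the parameters that take part in gradient computation:

    def filter_para_grad(plin):
        # assumed behaviour: drop parameters frozen with requires_grad=False so they
        # are excluded from gradient reduction and parameter broadcasting
        return [para for para in plin if para.requires_grad]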
Example #3
    def collect_gradients_func(self, func):

        if self.ngradev > 1:
            grads = comm.reduce_add_coalesced(
                [[p.grad for p in filter_para_grad(func(net).parameters())]
                 for net in self.nets[:self.ngradev]], self.output_device)
            for mp, grad in zip(
                    filter_para_grad(func(self.module).parameters()), grads):
                mp.grad = grad
Example #4
    def zero_replicas_grad(self, func=None):

        if self.nets is not None and self.ngradev > 1:
            if func is None:
                for net in self.nets[1:self.ngradev]:
                    for para in filter_para_grad(net.parameters()):
                        para.grad = None
            else:
                for net in self.nets[1:self.ngradev]:
                    for para in filter_para_grad(func(net).parameters()):
                        para.grad = None
Example #5
	def update_replicas_para(self, parallel=False):

		params = [para.data for para in filter_para_grad(self.module.parameters())]

		if len(params) > 0:
			param_copies = tuple([t for tensors in comm.broadcast_coalesced(params, self.device_ids) for t in tensors])

			# PyTorch's broadcast binds the parameters of self.nets[0] to those of self.module,
			# so self.nets[0] can be skipped here; the commented-out loop below (over all
			# replicas) would also be correct, but less efficient.
			#for module, param_copy in zip(self.nets, [param_copies[i:i + len(params)] for i in range(0, len(param_copies), len(params))]):
			for module, param_copy in zip(self.nets[1:], [param_copies[i:i + len(params)] for i in range(len(params), len(param_copies), len(params))]):
				for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
					mp.data = para
		self.ngradev = 0
Example #6
	def update_replicas(self, parallel=False):

		params = [para.data for para in filter_para_grad(self.module.parameters())]

		if len(params) > 0:
			param_copies = tuple([t for tensors in comm.broadcast_coalesced(params, self.device_ids) for t in tensors])

			# Currently, PyTorch's broadcast binds the parameters of self.nets[0] to those of
			# self.module, so self.nets[0] can be skipped here; the commented-out loop below
			# (over all replicas) would also be correct, but less efficient.
			#for module, param_copy in zip(self.nets, [param_copies[i:i + len(params)] for i in range(0, len(param_copies), len(params))]):
			for module, param_copy in zip(self.nets[1:], [param_copies[i:i + len(params)] for i in range(len(params), len(param_copies), len(params))]):
				for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
					mp.data, mp.grad = para, None

		self.ngradev = 0
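
Taken together, the methods above are presumably driven by a training step roughly like the following. The wrapper instance `net`, the `optimizer`, and the `batch` handling are assumptions for illustration; only the method names come from the examples.

	# hedged usage sketch, not code from the repository
	loss = net(batch).sum()    # forward pass over up to net.ngradev replicas
	loss.backward()            # per-replica gradients
	net.collect_gradients()    # sum the replica gradients onto the master module (Example #1)
	optimizer.step()           # update the master parameters
	net.reset_grad()           # release gradients on the master and the replicas (Example #2)
	net.update_replicas()      # re-broadcast the updated parameters to the replicas (Example #6)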
Example #7
	def allocate(self, parameters=None, init_tensors=None):

		self.pll = self.pll if parameters is None else [filter_para_grad(pl) for pl in parameters]
		cpl = []
		if init_tensors is None:
			for pl in self.pll:
				if len(pl) > 1:
					_numel = sum(para.numel() for para in pl)
					_weight = nn.Parameter(pl[0].new_empty(_numel))
					_weight.grad = pl[0].new_zeros(_numel)
					cpl.append(_weight)
				else:
					_weight = pl[0]
					if _weight.grad is None:
						_weight.grad = _weight.new_zeros(_weight.size())
					cpl.append(_weight)
		else:
			for pl, init_tensor in zip(self.pll, init_tensors):
				if len(pl) > 1:
					_numel = sum(para.numel() for para in pl) if init_tensor is None else init_tensor.numel()
					_weight = nn.Parameter(pl[0].new_empty(_numel) if init_tensor is None else init_tensor)
					_weight.grad = pl[0].new_zeros(_numel) if (init_tensor is None) or (init_tensor.grad is None) else init_tensor.grad # _numel already accounts for a possibly None init_tensor
					cpl.append(_weight)
				else:
					_weight = pl[0]
					if _weight.grad is None:
						_weight.grad = _weight.new_zeros(_weight.size()) if (init_tensor is None) or (init_tensor.grad is None) else init_tensor.grad.view(_weight.size())
					cpl.append(_weight)
		self.weights = nn.ParameterList(cpl)
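
`allocate` packs each parameter group with more than one member into a single flat buffer plus a matching flat gradient. A tiny self-contained sketch of that sizing logic, using hypothetical shapes:

	import torch
	from torch import nn

	pl = [torch.empty(3, 4), torch.empty(5)]         # hypothetical parameter group
	_numel = sum(para.numel() for para in pl)        # 12 + 5 = 17
	_weight = nn.Parameter(pl[0].new_empty(_numel))  # flat weight holding all 17 elements
	_weight.grad = pl[0].new_zeros(_numel)           # matching flat gradient buffer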
Example #8
    def zero_grad(self, set_to_none=True):

        if self.is_contiguous_parameters:
            with torch.no_grad():
                for para in get_all_contiguous_parameters_m(self.module):
                    para.grad.zero_()
                if self.nets is not None and self.ngradev > 1:
                    for net in self.nets[1:self.ngradev]:
                        for para in get_all_contiguous_parameters_m(net):
                            para.grad.zero_()
        else:
            for para in filter_para_grad(self.module.parameters()):
                para.grad = None
            if self.nets is not None and self.ngradev > 1:
                for net in self.nets[1:self.ngradev]:
                    for para in filter_para_grad(net.parameters()):
                        para.grad = None
        self.ngradev = 0
Example #9
    def build_optimizer(self,
                        optm_func,
                        *optm_args,
                        multi_gpu_optimizer=False,
                        contiguous_parameters=False,
                        **optm_kwargs):

        self.is_contiguous_parameters = contiguous_parameters
        paras = filter_para_grad(self.module.parameters())
        if (not multi_gpu_optimizer) or self.nets is None or (len(paras) < 2):
            if contiguous_parameters:
                if self.nets is not None:
                    for net in self.nets[1:]:
                        get_contiguous_parameters_m(net)
                _mp = get_contiguous_parameters_m(self.module)
            else:
                _mp = self.module.parameters()
            return optm_func(_mp, *optm_args, **optm_kwargs)
        else:
            self.optm_splt, _np = divide_para_ind(paras,
                                                  len(self.device_ids),
                                                  return_np=True)
            if contiguous_parameters:
                for net in self.nets:
                    get_contiguous_parameters_p(
                        [list(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter))
                         for lind, rind in self.optm_splt],
                        model=net)
                optml = [
                    optm_func(get_contiguous_parameters_m(net, index=i), *optm_args, **optm_kwargs)
                    for i, net in enumerate(self.nets)
                ]
            else:
                optml = [
                    optm_func(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter),
                              *optm_args, **optm_kwargs)
                    for net, (lind, rind) in zip(self.nets, self.optm_splt)
                ]
            # put the optimizers that hold slightly more parameters first, so their optimization steps can start earlier
            optml, _device_ids = reorder_by_sort(_np,
                                                 optml,
                                                 self.device_ids[:len(optml)],
                                                 reverse=True)
            return MultiGPUOptimizer(optml, device_ids=_device_ids)
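
`divide_para_ind` and `reorder_by_sort` are helpers not shown in these examples. A hypothetical sketch of `reorder_by_sort` that is consistent with the call above, reordering the trailing sequences by the key sequence (largest first when `reverse=True`):

    def reorder_by_sort(keys, *seqs, reverse=False):
        # assumed behaviour: sort the indices by their key (here: per-shard parameter
        # counts) and apply the same permutation to every trailing sequence
        order = sorted(range(len(keys)), key=keys.__getitem__, reverse=reverse)
        return tuple([seq[i] for i in order] for seq in seqs)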
Example #10
    def zero_replicas_grad(self, func=None):

        if self.nets is not None and self.ngradev > 1:
            if func is None:
                if self.is_contiguous_parameters:
                    for net in self.nets[1:self.ngradev]:
                        for para in get_all_contiguous_parameters_m(net):
                            para.grad.zero_()
                else:
                    for net in self.nets[1:self.ngradev]:
                        for para in filter_para_grad(net.parameters()):
                            para.grad = None
            else:
                if self.is_contiguous_parameters:
                    for net in self.nets[1:self.ngradev]:
                        for para in filter_para_grad(func(net).parameters()):
                            para.grad.zero_()
                else:
                    for net in self.nets[1:self.ngradev]:
                        for para in filter_para_grad(func(net).parameters()):
                            para.grad = None
Example #11
    def zero_replicas_grad(self):

        if self.nets is not None and self.ngradev > 1:
            for net in self.nets[1:self.ngradev]:
                for para in filter_para_grad(net.parameters()):
                    para.grad = None
Example #12
    def update_replicas_para(self):

        if self.optm_splt is None:
            if self.is_contiguous_parameters:
                params = [
                    para.data
                    for para in get_all_contiguous_parameters_m(self.module)
                ]
                param_copies = comm.broadcast_coalesced(
                    params, self.device_ids)
                with torch.no_grad():
                    for module, param_copy in zip(self.nets[1:],
                                                  param_copies[1:]):
                        for mp, para in zip(
                                get_all_contiguous_parameters_m(module),
                                param_copy):
                            mp.data.copy_(para)
            else:
                params = [
                    para.data
                    for para in filter_para_grad(self.module.parameters())
                ]
                param_copies = comm.broadcast_coalesced(
                    params, self.device_ids)
                for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                    for mp, para in zip(filter_para_grad(module.parameters()),
                                        param_copy):
                        mp.data = para
        else:
            for i, (net, (lind, rind)) in enumerate(zip(self.nets, self.optm_splt)):
                _dev_params = [
                    para.data
                    for para in get_contiguous_parameters_m(net, index=i)
                ] if self.is_contiguous_parameters else [
                    para.data for para in range_parameter_iter(
                        net, lind, rind, func=filter_para_grad_iter)
                ]
                if i > 0:
                    _devices = self.device_ids[:]
                    _devices.insert(0, _devices.pop(i))
                else:
                    _devices = self.device_ids
                _dev_param_copies = comm.broadcast_coalesced(
                    _dev_params, _devices)
                if i > 0:
                    _dev_param_copies.insert(i, _dev_param_copies.pop(0))
                    for pc, _dpc in zip(param_copies, _dev_param_copies):
                        pc.extend(_dpc)
                else:
                    param_copies = _dev_param_copies
                if self.is_contiguous_parameters:
                    with torch.no_grad():
                        for module, param_copy in zip(self.nets, param_copies):
                            for mp, para in zip(
                                    get_all_contiguous_parameters_m(module),
                                    param_copy):
                                mp.data.copy_(para)
                else:
                    for module, param_copy in zip(self.nets, param_copies):
                        for mp, para in zip(
                                filter_para_grad(module.parameters()),
                                param_copy):
                            mp.data = para

        self.ngradev = 0
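
The device shuffling in the `optm_splt` branch relies on `comm.broadcast_coalesced` treating the first entry of the device list as the source device. A small self-contained sketch of the rotation, assuming `device_ids = [0, 1, 2, 3]` and shard index `i = 2`:

    device_ids = [0, 1, 2, 3]
    i = 2                                  # replica that owns this parameter shard
    _devices = device_ids[:]
    _devices.insert(0, _devices.pop(i))    # [2, 0, 1, 3]: source device moved to the front
    copies = ["copy_on_%d" % d for d in _devices]
    copies.insert(i, copies.pop(0))        # back in device_ids order: copies[j] matches device_ids[j]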
Example #13
    def update_replicas(self):

        if self.optm_splt is None:
            if self.is_contiguous_parameters:
                params = [
                    para.data
                    for para in get_all_contiguous_parameters_m(self.module)
                ]
                param_copies = comm.broadcast_coalesced(
                    params, self.device_ids)
                with torch.no_grad():
                    for module, param_copy in zip(self.nets[1:],
                                                  param_copies[1:]):
                        for mp, para in zip(
                                get_all_contiguous_parameters_m(module),
                                param_copy):
                            mp.data.copy_(para)
                            mp.grad.zero_()
            else:
                params = [
                    para.data
                    for para in filter_para_grad(self.module.parameters())
                ]
                param_copies = comm.broadcast_coalesced(
                    params, self.device_ids)
                # Currently, PyTorch's broadcast binds the parameters of self.nets[0] to those of
                # self.module, so self.nets[0] can be skipped here; the commented-out loop below
                # (over all replicas) would also be correct, but less efficient.
                #for module, param_copy in zip(self.nets, param_copies):
                for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                    for mp, para in zip(filter_para_grad(module.parameters()),
                                        param_copy):
                        mp.data, mp.grad = para, None
        else:
            for i, (net, (lind, rind)) in enumerate(zip(self.nets, self.optm_splt)):
                _dev_params = [
                    para.data
                    for para in get_contiguous_parameters_m(net, index=i)
                ] if self.is_contiguous_parameters else [
                    para.data for para in range_parameter_iter(
                        net, lind, rind, func=filter_para_grad_iter)
                ]
                if i > 0:
                    _devices = self.device_ids[:]
                    _devices.insert(0, _devices.pop(i))
                else:
                    _devices = self.device_ids
                _dev_param_copies = comm.broadcast_coalesced(
                    _dev_params, _devices)
                if i > 0:
                    _dev_param_copies.insert(i, _dev_param_copies.pop(0))
                    for pc, _dpc in zip(param_copies, _dev_param_copies):
                        pc.extend(_dpc)
                else:
                    param_copies = _dev_param_copies
            if self.is_contiguous_parameters:
                with torch.no_grad():
                    for module, param_copy in zip(self.nets, param_copies):
                        for mp, para in zip(
                                get_all_contiguous_parameters_m(module),
                                param_copy):
                            mp.data.copy_(para)
                            mp.grad.zero_()
            else:
                for module, param_copy in zip(self.nets, param_copies):
                    for mp, para in zip(filter_para_grad(module.parameters()),
                                        param_copy):
                        mp.data, mp.grad = para, None

        self.ngradev = 0
Example #14
    def collect_gradients(self):

        if self.optm_splt is not None:
            if self.is_contiguous_parameters:
                for i, (net, device) in enumerate(zip(self.nets, self.device_ids)):
                    _dev_grads = [[
                        para.grad
                        for para in get_contiguous_parameters_m(_net, index=i)
                    ] for _net in self.nets[:self.ngradev]]
                    if i > 0:
                        _dev_grads.insert(
                            0,
                            _dev_grads.pop(i) if i < self.ngradev else [
                                para.grad for para in
                                get_contiguous_parameters_m(net, index=i)
                            ])
                    _dev_grads = comm.reduce_add_coalesced(_dev_grads, device)
                    for mp, grad in zip(
                            get_contiguous_parameters_m(net, index=i),
                            _dev_grads):
                        mp.grad.copy_(grad)
            else:
                grads = [[
                    para.grad for para in filter_para_grad(net.parameters())
                ] for net in self.nets[:self.ngradev]]
                for i, (net, device, (lind, rind)) in enumerate(
                        zip(self.nets, self.device_ids, self.optm_splt)):
                    _dev_grads = [gradu[lind:rind] for gradu in grads]
                    if i > 0:
                        _dev_grads.insert(
                            0,
                            _dev_grads.pop(i) if i < self.ngradev else [
                                _pg.new_zeros(_pg.size(), device=device)
                                for _pg in _dev_grads[0]
                            ])
                    _dev_grads = comm.reduce_add_coalesced(_dev_grads, device)
                    for mp, grad in zip(
                            range_parameter_iter(net,
                                                 lind,
                                                 rind,
                                                 func=filter_para_grad_iter),
                            _dev_grads):
                        mp.grad = grad
        elif self.ngradev > 1:
            if self.is_contiguous_parameters:
                grads = comm.reduce_add_coalesced([[
                    para.grad for para in get_all_contiguous_parameters_m(net)
                ] for net in self.nets[:self.ngradev]], self.output_device)
                for mp, grad in zip(
                        get_all_contiguous_parameters_m(self.module), grads):
                    mp.grad.copy_(grad)
            else:
                # If some parameters may be unused during forward propagation on some GPUs, use
                # `para.data.new_zeros(para.data.size()) if para.grad is None else para.grad`
                # instead of `para.grad`; in most cases, however, keeping `para.grad` is
                # preferable, since the resulting failure warns you when a parameter is missed
                # in the forward computation.
                grads = comm.reduce_add_coalesced(
                    [[
                        para.grad
                        for para in filter_para_grad(net.parameters())
                    ] for net in self.nets[:self.ngradev]], self.output_device
                )  # if self.ngradev > 1 else [p.grad for p in filter_para_grad(self.nets[0].parameters())]
                for mp, grad in zip(filter_para_grad(self.module.parameters()),
                                    grads):
                    mp.grad = grad