def update(self, index, weight, grad, state):
    """Update method."""
    try:
        from mxnet.ndarray.contrib import adamw_update
    except ImportError:
        raise ImportError(
            "Failed to import nd.contrib.adamw_update from MXNet. "
            "BERTAdam optimizer requires mxnet>=1.5.0b20181228. "
            "Please upgrade your MXNet version.")
    assert isinstance(weight, NDArray)
    assert isinstance(grad, NDArray)
    self._update_count(index)
    lr = self._get_lr(index)
    wd = self._get_wd(index)

    kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
              'rescale_grad': self.rescale_grad}
    if self.clip_gradient:
        kwargs['clip_gradient'] = self.clip_gradient

    mean, var = state
    # lr=1 and eta=lr push the scheduled learning rate into the decoupled eta
    # term, so it scales both the Adam step and the weight decay.
    adamw_update(weight, grad, mean, var, out=weight,
                 lr=1, wd=wd, eta=lr, **kwargs)
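# A minimal NumPy sketch of what a single adamw_update call above is documented to
# compute (per the mxnet.ndarray.contrib.adamw_update formula), with lr=1 and eta set
# to the scheduled learning rate. Names and default values here are illustrative, not
# part of the optimizer's API.
import numpy as np

def adamw_reference_step(w, g, m, v, eta, wd, beta1=0.9, beta2=0.999,
                         epsilon=1e-6, rescale_grad=1.0, clip_gradient=None):
    g = g * rescale_grad
    if clip_gradient is not None:
        g = np.clip(g, -clip_gradient, clip_gradient)
    m[:] = beta1 * m + (1. - beta1) * g        # first moment
    v[:] = beta2 * v + (1. - beta2) * g * g    # second moment
    # eta scales both the Adam term and the decoupled weight decay term.
    w[:] = w - eta * (m / (np.sqrt(v) + epsilon) + wd * w)
    return w, m, v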
def _update_impl(self, indices, weight, grad, state, multi_precision=False):
    """Update function."""
    try:
        from mxnet.ndarray.contrib import adamw_update, mp_adamw_update
    except ImportError:
        raise ImportError(
            'Failed to import nd.contrib.adamw_update and '
            'nd.contrib.mp_adamw_update from MXNet. '
            'BERTAdam optimizer requires mxnet>=1.5.0b20190220. '
            'Please upgrade your MXNet version. For example: '
            'pip uninstall mxnet-cu90 -y; pip install mxnet-cu90 --pre')
    self._update_count(indices)
    lr = self._get_lr(indices)
    wd = self._get_wd(indices)

    # pylint: disable=access-member-before-definition
    if not isinstance(self.rescale_grad, NDArray):
        self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weight.context)
    else:
        self.rescale_grad = self.rescale_grad.as_in_context(weight.context)

    kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
              'rescale_grad': self.rescale_grad}
    if self.clip_gradient:
        kwargs['clip_gradient'] = self.clip_gradient

    if not multi_precision:
        mean, var = state
        adamw_update(weight, grad, mean, var, out=weight,
                     lr=1, wd=wd, eta=lr, **kwargs)
    else:
        # Multi-precision state carries (mean, var) in state[0] and the
        # float32 master copy of the weight in state[1].
        mean, var = state[0]
        mp_adamw_update(weight, grad, mean, var, state[1], out=weight,
                        lr=1, wd=wd, eta=lr, **kwargs)
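# Hedged sketch of the multi-precision branch, reusing adamw_reference_step from the
# sketch above. Assumes the standard mixed-precision pattern: float32 moments and a
# float32 master weight, with the float16 view refreshed after each step. Names are
# illustrative only.
import numpy as np

def mp_adamw_reference_step(w16, g16, m, v, w32, eta, wd, **adam_kwargs):
    # Do the arithmetic in float32 against the master copy ...
    adamw_reference_step(w32, g16.astype(np.float32), m, v, eta, wd, **adam_kwargs)
    # ... then refresh the float16 weights that the network actually uses.
    w16[:] = w32.astype(np.float16)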
def _update_impl(self, indices, weight, grad, state, multi_precision=False):
    """Update function."""
    self._update_count(indices)
    lr = self._get_lr(indices)
    wd = self._get_wd(indices)

    # pylint: disable=access-member-before-definition
    if not isinstance(self.rescale_grad, NDArray):
        self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weight.context)
    else:
        self.rescale_grad = self.rescale_grad.as_in_context(weight.context)

    kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
              'rescale_grad': self.rescale_grad}
    if self.clip_gradient:
        kwargs['clip_gradient'] = self.clip_gradient

    if not multi_precision:
        mean, var = state
        adamw_update(weight, grad, mean, var, out=weight,
                     lr=1, wd=wd, eta=lr, **kwargs)
    else:
        mean, var = state[0]
        mp_adamw_update(weight, grad, mean, var, state[1], out=weight,
                        lr=1, wd=wd, eta=lr, **kwargs)
def _update_impl(self, indices, weight, grad, state, multi_precision=False):
    """Update function."""
    aggregate = self.aggregate_num > 1
    if not isinstance(indices, (tuple, list)):
        indices = [indices]
        weight = [weight]
        grad = [grad]
        state = [state]
    for w_i, g_i in zip(weight, grad):
        assert isinstance(w_i, NDArray)
        assert isinstance(g_i, NDArray)
        # Aggregated (multi-tensor) updates are only supported for dense arrays.
        aggregate = (aggregate and
                     w_i.stype == 'default' and
                     g_i.stype == 'default')
    self._update_count(indices)
    lrs = self._get_lrs(indices)
    wds = self._get_wds(indices)

    # pylint: disable=access-member-before-definition
    if not isinstance(self.rescale_grad, NDArray):
        self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weight[0].context)
    else:
        self.rescale_grad = self.rescale_grad.as_in_context(weight[0].context)

    kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
              'rescale_grad': self.rescale_grad}
    if self.clip_gradient:
        kwargs['clip_gradient'] = self.clip_gradient

    if aggregate:
        # Update the parameters in chunks of at most `aggregate_num` per fused kernel call.
        current_index = 0
        while current_index < len(indices):
            sidx = current_index
            eidx = min(current_index + self.aggregate_num, len(indices))
            if not multi_precision:
                mean, var = list(zip(*state[sidx:eidx]))
                multi_adamw_update(weight[sidx:eidx], grad[sidx:eidx],
                                   mean, var, out=weight[sidx:eidx],
                                   size=len(weight[sidx:eidx]),
                                   lrs=list(numpy.ones(len(weight[sidx:eidx]))),
                                   wds=wds[sidx:eidx], etas=lrs[sidx:eidx], **kwargs)
            else:
                mean_var = list(zip(*state[sidx:eidx]))[0]
                tmean_var = list(zip(*mean_var))
                mean = tmean_var[0]
                var = tmean_var[1]
                multi_mp_adamw_update(weight[sidx:eidx], grad[sidx:eidx],
                                      mean, var,
                                      list(zip(*state[sidx:eidx]))[1],
                                      out=weight[sidx:eidx],
                                      size=len(weight[sidx:eidx]),
                                      lrs=list(numpy.ones(len(weight[sidx:eidx]))),
                                      wds=wds[sidx:eidx], etas=lrs[sidx:eidx], **kwargs)
            current_index += self.aggregate_num
    else:
        # Fall back to one fused kernel call per parameter.
        for w_i, g_i, s_i, lr, wd in zip(weight, grad, state, lrs, wds):
            if not multi_precision:
                mean, var = s_i
                adamw_update(w_i, g_i, mean, var, out=w_i,
                             lr=1, wd=wd, eta=lr, **kwargs)
            else:
                mean, var = s_i[0]
                mp_adamw_update(w_i, g_i, mean, var, s_i[1], out=w_i,
                                lr=1, wd=wd, eta=lr, **kwargs)
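# Standalone illustration (placeholder data, not real optimizer state) of how the
# aggregate branch above regroups per-parameter states before calling the
# multi-tensor kernels.
states = [('mean0', 'var0'), ('mean1', 'var1'), ('mean2', 'var2')]
mean, var = list(zip(*states))
print(mean)  # ('mean0', 'mean1', 'mean2')
print(var)   # ('var0', 'var1', 'var2')

# In the multi-precision case each state is ((mean, var), weight_fp32), matching how
# the code unpacks state[0] and state[1]; the float32 master weights are pulled out
# with a second zip over the same slice.
mp_states = [(('mean0', 'var0'), 'w32_0'), (('mean1', 'var1'), 'w32_1')]
mean_var, weights32 = zip(*mp_states)
mean, var = zip(*mean_var)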
def fused_step(self, indices, weights, grads, states):
    """Perform a fused optimization step using gradients and states.

    Fused kernel is used for the update.

    Parameters
    ----------
    indices : list of int
        List of unique indices of the parameters into the individual learning rates
        and weight decays. Learning rates and weight decay may be set via
        `set_lr_mult()` and `set_wd_mult()`, respectively.
    weights : list of NDArray
        List of parameters to be updated.
    grads : list of NDArray
        List of gradients of the objective with respect to the parameters.
    states : list of any obj
        List of states returned by `create_state()`.
    """
    multi_precision = self.multi_precision and weights[0].dtype == np.float16
    aggregate = self.aggregate_num > 1
    if not isinstance(indices, (tuple, list)):
        indices = [indices]
        weights = [weights]
        grads = [grads]
        states = [states]
    for w_i, g_i in zip(weights, grads):
        assert isinstance(w_i, mx.nd.NDArray)
        assert isinstance(g_i, mx.nd.NDArray)
        # Aggregated (multi-tensor) updates are only supported for dense arrays.
        aggregate = (aggregate and
                     w_i.stype == 'default' and
                     g_i.stype == 'default')
    self._update_count(indices)
    lrs = self._get_lrs(indices)
    wds = self._get_wds(indices)

    if self.correct_bias:
        # Fold Adam's bias correction into the per-parameter learning rate,
        # since the fused kernel itself does not apply it.
        new_lrs = []
        for idx, lr in zip(indices, lrs):
            t = self._index_update_count[idx]
            coef1 = 1. - self.beta1**t
            coef2 = 1. - self.beta2**t
            new_lrs.append(lr * math.sqrt(coef2) / coef1)
        lrs = new_lrs

    if not isinstance(self.rescale_grad, mx.nd.NDArray):
        self.rescale_grad = mx.nd.full(shape=(1,), val=self.rescale_grad,
                                       ctx=weights[0].context)
    else:
        self.rescale_grad = self.rescale_grad.as_in_context(weights[0].context)

    kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
              'rescale_grad': self.rescale_grad}
    if self.clip_gradient:
        kwargs['clip_gradient'] = self.clip_gradient

    if aggregate:
        # Update the parameters in chunks of at most `aggregate_num` per fused kernel call.
        current_index = 0
        while current_index < len(indices):
            sidx = current_index
            eidx = min(current_index + self.aggregate_num, len(indices))
            if not multi_precision:
                mean, var = list(zip(*states[sidx:eidx]))
                multi_adamw_update(weights[sidx:eidx], grads[sidx:eidx],
                                   mean, var, out=weights[sidx:eidx],
                                   size=len(weights[sidx:eidx]),
                                   lrs=list(np.ones(len(weights[sidx:eidx]))),
                                   wds=wds[sidx:eidx], etas=lrs[sidx:eidx], **kwargs)
            else:
                mean_var = list(zip(*states[sidx:eidx]))[0]
                tmean_var = list(zip(*mean_var))
                mean = tmean_var[0]
                var = tmean_var[1]
                multi_mp_adamw_update(weights[sidx:eidx], grads[sidx:eidx],
                                      mean, var,
                                      list(zip(*states[sidx:eidx]))[1],
                                      out=weights[sidx:eidx],
                                      size=len(weights[sidx:eidx]),
                                      lrs=list(np.ones(len(weights[sidx:eidx]))),
                                      wds=wds[sidx:eidx], etas=lrs[sidx:eidx], **kwargs)
            current_index += self.aggregate_num
    else:
        # Fall back to one fused kernel call per parameter.
        for w_i, g_i, s_i, lr, wd in zip(weights, grads, states, lrs, wds):
            if not multi_precision:
                mean, var = s_i
                adamw_update(w_i, g_i, mean, var, out=w_i,
                             lr=1, wd=wd, eta=lr, **kwargs)
            else:
                mean, var = s_i[0]
                mp_adamw_update(w_i, g_i, mean, var, s_i[1], out=w_i,
                                lr=1, wd=wd, eta=lr, **kwargs)
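# Hedged usage sketch. Assumes the method above belongs to a BERTAdam-style optimizer
# exposed as gluonnlp.optimizer.BERTAdam (as in GluonNLP 0.x); the import path and the
# hyperparameter values below are illustrative assumptions, not prescribed settings.
import mxnet as mx
from mxnet.gluon import nn
from gluonnlp.optimizer import BERTAdam

net = nn.Dense(2)
net.initialize(ctx=mx.cpu())

trainer = mx.gluon.Trainer(net.collect_params(),
                           BERTAdam(learning_rate=5e-5, epsilon=1e-6, wd=0.01))

x = mx.nd.random.uniform(shape=(4, 8))
with mx.autograd.record():
    loss = net(x).sum()
loss.backward()
# Trainer.step() dispatches to the optimizer's update/fused_step path shown above.
trainer.step(batch_size=4)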