Example #1
    def update(self, index, weight, grad, state):
        """Update method."""
        try:
            from mxnet.ndarray.contrib import adamw_update
        except ImportError:
            raise ImportError(
                "Failed to import nd.contrib.adamw_update from MXNet. "
                "BERTAdam optimizer requires mxnet>=1.5.0b20181228. "
                "Please upgrade your MXNet version.")
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        kwargs = {
            'beta1': self.beta1,
            'beta2': self.beta2,
            'epsilon': self.epsilon,
            'rescale_grad': self.rescale_grad
        }
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient

        mean, var = state
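        # The scheduled learning rate is passed as eta while lr stays 1, so eta
        # scales both the Adam step and the decoupled weight decay inside the op.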
        adamw_update(weight,
                     grad,
                     mean,
                     var,
                     out=weight,
                     lr=1,
                     wd=wd,
                     eta=lr,
                     **kwargs)
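
For orientation, a minimal NumPy sketch of the per-parameter arithmetic the fused
kernel is understood to perform (names and defaults here are illustrative; the real
op runs in a single kernel launch and also has the multi-precision and multi-tensor
variants shown in the later examples):

    import numpy as np

    def adamw_step_sketch(weight, grad, mean, var, eta, wd=0.0, lr=1.0,
                          beta1=0.9, beta2=0.999, epsilon=1e-6,
                          rescale_grad=1.0, clip_gradient=None):
        # Rescale and optionally clip the raw gradient.
        g = grad * rescale_grad
        if clip_gradient is not None:
            g = np.clip(g, -clip_gradient, clip_gradient)
        # Exponential moving averages of the gradient and its square.
        mean[:] = beta1 * mean + (1.0 - beta1) * g
        var[:] = beta2 * var + (1.0 - beta2) * g * g
        # Decoupled weight decay: eta scales both the Adam step and wd * weight,
        # which is why the caller passes lr=1 and eta=<scheduled learning rate>.
        weight[:] -= eta * (lr * mean / (np.sqrt(var) + epsilon) + wd * weight)
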
Example #2
    def _update_impl(self,
                     indices,
                     weight,
                     grad,
                     state,
                     multi_precision=False):
        """update function"""
        try:
            from mxnet.ndarray.contrib import adamw_update, mp_adamw_update
        except ImportError:
            raise ImportError(
                'Failed to import nd.contrib.adamw_update and '
                'nd.contrib.mp_adamw_update from MXNet. '
                'BERTAdam optimizer requires mxnet>=1.5.0b20190220. '
                'Please upgrade your MXNet version. For example: '
                'pip uninstall mxnet-cu90 -y; pip install mxnet-cu90 --pre')
        self._update_count(indices)
        lr = self._get_lr(indices)
        wd = self._get_wd(indices)

        # pylint: disable=access-member-before-definition
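        # rescale_grad is handed to the fused op as an NDArray, so materialize it
        # once on the same context as the weight before putting it into kwargs.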
        if not isinstance(self.rescale_grad, NDArray):
            self.rescale_grad = full(shape=(1, ),
                                     val=self.rescale_grad,
                                     ctx=weight.context)
        else:
            self.rescale_grad = self.rescale_grad.as_in_context(weight.context)

        kwargs = {
            'beta1': self.beta1,
            'beta2': self.beta2,
            'epsilon': self.epsilon,
            'rescale_grad': self.rescale_grad
        }
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient
        if not multi_precision:
            mean, var = state
            adamw_update(weight,
                         grad,
                         mean,
                         var,
                         out=weight,
                         lr=1,
                         wd=wd,
                         eta=lr,
                         **kwargs)
        else:
            mean, var = state[0]
            mp_adamw_update(weight,
                            grad,
                            mean,
                            var,
                            state[1],
                            out=weight,
                            lr=1,
                            wd=wd,
                            eta=lr,
                            **kwargs)
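
Note the state layout the two branches imply: in single precision `state` is the
`(mean, var)` pair itself, while under `multi_precision` the code expects `state[0]`
to be `(mean, var)` and `state[1]` to be the float32 master copy of the float16
weight that `mp_adamw_update` updates alongside it.
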
Example #3
    def _update_impl(self, indices, weight, grad, state, multi_precision=False):
        """update function"""
        self._update_count(indices)
        lr = self._get_lr(indices)
        wd = self._get_wd(indices)

        # pylint: disable=access-member-before-definition
        if not isinstance(self.rescale_grad, NDArray):
            self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weight.context)
        else:
            self.rescale_grad = self.rescale_grad.as_in_context(weight.context)

        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
                  'rescale_grad': self.rescale_grad}
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient
        if not multi_precision:
            mean, var = state
            adamw_update(weight, grad, mean, var, out=weight,
                         lr=1, wd=wd, eta=lr, **kwargs)
        else:
            mean, var = state[0]
            mp_adamw_update(weight, grad, mean, var, state[1], out=weight,
                            lr=1, wd=wd, eta=lr, **kwargs)
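
This variant drops the guarded import and assumes the helpers are already available
at module scope; a plausible import block (an assumption, not part of the snippet)
would be:

    from mxnet.ndarray import NDArray, full
    from mxnet.ndarray.contrib import adamw_update, mp_adamw_update
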
Example #4
    def _update_impl(self,
                     indices,
                     weight,
                     grad,
                     state,
                     multi_precision=False):
        """update function"""
        aggregate = self.aggregate_num > 1
        if not isinstance(indices, (tuple, list)):
            indices = [indices]
            weight = [weight]
            grad = [grad]
            state = [state]
        for w_i, g_i in zip(weight, grad):
            assert isinstance(w_i, NDArray)
            assert isinstance(g_i, NDArray)
            aggregate = (aggregate and w_i.stype == 'default'
                         and g_i.stype == 'default')
        self._update_count(indices)
        lrs = self._get_lrs(indices)
        wds = self._get_wds(indices)

        # pylint: disable=access-member-before-definition
        if not isinstance(self.rescale_grad, NDArray):
            self.rescale_grad = full(shape=(1, ),
                                     val=self.rescale_grad,
                                     ctx=weight[0].context)
        else:
            self.rescale_grad = self.rescale_grad.as_in_context(
                weight[0].context)

        kwargs = {
            'beta1': self.beta1,
            'beta2': self.beta2,
            'epsilon': self.epsilon,
            'rescale_grad': self.rescale_grad
        }
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient

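        # Aggregated path: update up to aggregate_num parameters per call to the
        # multi-tensor kernels instead of one kernel launch per parameter.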
        if aggregate:
            current_index = 0
            while current_index < len(indices):
                sidx = current_index
                eidx = min(current_index + self.aggregate_num, len(indices))
                if not multi_precision:
                    mean, var = list(zip(*state[sidx:eidx]))
                    multi_adamw_update(weight[sidx:eidx],
                                       grad[sidx:eidx],
                                       mean,
                                       var,
                                       out=weight[sidx:eidx],
                                       size=len(weight[sidx:eidx]),
                                       lrs=list(
                                           numpy.ones(len(weight[sidx:eidx]))),
                                       wds=wds[sidx:eidx],
                                       etas=lrs[sidx:eidx],
                                       **kwargs)
                else:
                    mean_var = list(zip(*state[sidx:eidx]))[0]
                    tmean_var = list(zip(*mean_var))
                    mean = tmean_var[0]
                    var = tmean_var[1]
                    multi_mp_adamw_update(
                        weight[sidx:eidx],
                        grad[sidx:eidx],
                        mean,
                        var,
                        list(zip(*state[sidx:eidx]))[1],
                        out=weight[sidx:eidx],
                        size=len(weight[sidx:eidx]),
                        lrs=list(numpy.ones(len(weight[sidx:eidx]))),
                        wds=wds[sidx:eidx],
                        etas=lrs[sidx:eidx],
                        **kwargs)
                current_index += self.aggregate_num
        else:
            for w_i, g_i, s_i, lr, wd in zip(weight, grad, state, lrs, wds):
                if not multi_precision:
                    mean, var = s_i
                    adamw_update(w_i,
                                 g_i,
                                 mean,
                                 var,
                                 out=w_i,
                                 lr=1,
                                 wd=wd,
                                 eta=lr,
                                 **kwargs)
                else:
                    mean, var = s_i[0]
                    mp_adamw_update(w_i,
                                    g_i,
                                    mean,
                                    var,
                                    s_i[1],
                                    out=w_i,
                                    lr=1,
                                    wd=wd,
                                    eta=lr,
                                    **kwargs)
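
The aggregated branch mirrors the single-tensor convention: `lrs` is a list of ones
and the scheduled learning rates are passed as `etas`, while
`list(zip(*state[sidx:eidx]))` transposes the per-parameter `(mean, var)` pairs into
one tuple of means and one of vars (plus, in the multi-precision case, a tuple of
float32 master weights) for the multi-tensor kernels.
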
Example #5
    def fused_step(self, indices, weights, grads, states):
        """Perform a fused optimization step using gradients and states.
        Fused kernel is used for update.

        Parameters
        ----------
        indices : list of int
            List of unique indices of the parameters, used to look up their individual
            learning rates and weight decays. Per-parameter multipliers may be set via
            `set_lr_mult()` and `set_wd_mult()`, respectively.
        weights : list of NDArray
            List of parameters to be updated.
        grads : list of NDArray
            List of gradients of the objective with respect to the parameters.
        states : list of any obj
            List of states returned by `create_state()`.
        """
        multi_precision = (self.multi_precision
                           and weights[0].dtype == np.float16)
        aggregate = self.aggregate_num > 1
        if not isinstance(indices, (tuple, list)):
            indices = [indices]
            weights = [weights]
            grads = [grads]
            states = [states]
        for w_i, g_i in zip(weights, grads):
            assert isinstance(w_i, mx.nd.NDArray)
            assert isinstance(g_i, mx.nd.NDArray)
            aggregate = (aggregate and w_i.stype == 'default'
                         and g_i.stype == 'default')
        self._update_count(indices)
        lrs = self._get_lrs(indices)
        wds = self._get_wds(indices)
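        # When correct_bias is set, fold Adam's bias correction into the
        # per-parameter step size: lr_t = lr * sqrt(1 - beta2**t) / (1 - beta1**t)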
        if self.correct_bias:
            new_lrs = []
            for idx, lr in zip(indices, lrs):
                t = self._index_update_count[idx]
                coef1 = 1. - self.beta1**t
                coef2 = 1. - self.beta2**t
                new_lrs.append(lr * math.sqrt(coef2) / coef1)
            lrs = new_lrs
        if not isinstance(self.rescale_grad, mx.nd.NDArray):
            self.rescale_grad = mx.nd.full(shape=(1, ),
                                           val=self.rescale_grad,
                                           ctx=weights[0].context)
        else:
            self.rescale_grad = self.rescale_grad.as_in_context(
                weights[0].context)
        kwargs = {
            'beta1': self.beta1,
            'beta2': self.beta2,
            'epsilon': self.epsilon,
            'rescale_grad': self.rescale_grad
        }
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient

        if aggregate:
            current_index = 0
            while current_index < len(indices):
                sidx = current_index
                eidx = min(current_index + self.aggregate_num, len(indices))
                if not multi_precision:
                    mean, var = list(zip(*states[sidx:eidx]))
                    multi_adamw_update(weights[sidx:eidx],
                                       grads[sidx:eidx],
                                       mean,
                                       var,
                                       out=weights[sidx:eidx],
                                       size=len(weights[sidx:eidx]),
                                       lrs=list(
                                           np.ones(len(weights[sidx:eidx]))),
                                       wds=wds[sidx:eidx],
                                       etas=lrs[sidx:eidx],
                                       **kwargs)
                else:
                    mean_var = list(zip(*states[sidx:eidx]))[0]
                    tmean_var = list(zip(*mean_var))
                    mean = tmean_var[0]
                    var = tmean_var[1]
                    multi_mp_adamw_update(
                        weights[sidx:eidx],
                        grads[sidx:eidx],
                        mean,
                        var,
                        list(zip(*states[sidx:eidx]))[1],
                        out=weights[sidx:eidx],
                        size=len(weights[sidx:eidx]),
                        lrs=list(np.ones(len(weights[sidx:eidx]))),
                        wds=wds[sidx:eidx],
                        etas=lrs[sidx:eidx],
                        **kwargs)
                current_index += self.aggregate_num
        else:
            for w_i, g_i, s_i, lr, wd in zip(weights, grads, states, lrs, wds):
                if not multi_precision:
                    mean, var = s_i
                    adamw_update(w_i,
                                 g_i,
                                 mean,
                                 var,
                                 out=w_i,
                                 lr=1,
                                 wd=wd,
                                 eta=lr,
                                 **kwargs)
                else:
                    mean, var = s_i[0]
                    mp_adamw_update(w_i,
                                    g_i,
                                    mean,
                                    var,
                                    s_i[1],
                                    out=w_i,
                                    lr=1,
                                    wd=wd,
                                    eta=lr,
                                    **kwargs)
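
For context, a minimal usage sketch, assuming these methods belong to an AdamW-style
MXNet optimizer such as GluonNLP's BERTAdam (the import path and constructor
arguments below are assumptions, not taken from the snippets):

    import mxnet as mx
    from gluonnlp.optimizer import BERTAdam  # assumed import path

    net = mx.gluon.nn.Dense(2)
    net.initialize()
    optimizer = BERTAdam(learning_rate=5e-5, wd=0.01)
    trainer = mx.gluon.Trainer(net.collect_params(), optimizer)
    # Each trainer.step(batch_size) call then routes the parameters, gradients and
    # optimizer states through the update routines shown above.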