Example No. 1
    def update(self, param, grad):
        """Performs a single optimization step.

        Arguments:
                param(Tensor): param values to be updated in place
                grad(Tensor): param gradients; the values may be updated
                        in this function and should not be used afterwards
        """
        group = self.default_config
        if param in self.param2config:
            group = self.param2config[param]
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        if weight_decay != 0:
            grad += param * weight_decay
        if momentum != 0:
            if param not in self.param2state:
                self.param2state[param] = {}
            param_state = self.param2state[param]
            if 'momentum_buffer' not in param_state:
                buf = param_state['momentum_buffer'] = tensor.zeros_like(param)
                buf *= momentum
                buf += grad
            else:
                buf = param_state['momentum_buffer']
                buf *= momentum
                buf += (1 - dampening) * grad
            if nesterov:
                grad += momentum * buf
            else:
                grad = buf
        param -= grad * group['lr']
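Example No. 1 implements SGD with momentum using plain tensor arithmetic. As a rough reference for the update it performs, here is the same step written with NumPy arrays; the function name sgd_momentum_step and the state dict are illustrative, not part of the SINGA API.

import numpy as np

def sgd_momentum_step(param, grad, state, lr=0.01, momentum=0.9,
                      weight_decay=0.0, dampening=0.0, nesterov=False):
    # optional L2 term: grad := grad + weight_decay * param
    if weight_decay != 0:
        grad = grad + weight_decay * param
    if momentum != 0:
        if 'momentum_buffer' not in state:
            # first step: the buffer starts out as the gradient itself
            state['momentum_buffer'] = grad.copy()
        else:
            # buf := momentum * buf + (1 - dampening) * grad
            state['momentum_buffer'] *= momentum
            state['momentum_buffer'] += (1 - dampening) * grad
        if nesterov:
            grad = grad + momentum * state['momentum_buffer']
        else:
            grad = state['momentum_buffer']
    # in-place parameter update: param := param - lr * grad
    param -= lr * grad

# illustrative usage
state = {}
w = np.ones(4, dtype=np.float32)
g = np.full(4, 0.5, dtype=np.float32)
sgd_momentum_step(w, g, state)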
Example No. 2
    def apply(self, param_name, param_value, param_grad):
        """Performs a single optimization step.

        Args:
                param_name(String): the name of the param
                param_value(Tensor): param values to be updated in place
                param_grad(Tensor): param gradients; the values may be updated
                        in this function and should not be used afterwards
        """
        assert param_value.shape == param_grad.shape, ("shape mismatch",
                                                       param_value.shape,
                                                       param_grad.shape)
        self.device_check(param_value, self.step_counter, self.lr_value,
                          self.mom_value, self.dam_value, self.decay_value)

        # the parameter's dtype must match the optimizer's dtype
        assert param_value.dtype == self.dtype

        # TODO add branch operator
        # if self.decay_value != 0:
        if self.weight_decay.init_value != 0:
            singa.Axpy(self.decay_value.data, param_value.data,
                       param_grad.data)

        if self.momentum.init_value != 0:
            if param_name not in self.moments:
                # state buffers are created with graph buffering temporarily
                # disabled, so the initialization runs eagerly and is not
                # recorded into the training graph
                flag = param_value.device.graph_enabled()
                param_value.device.EnableGraph(False)
                self.moments[param_name] = tensor.zeros_like(param_value)
                param_value.device.EnableGraph(flag)

            buf = self.moments[param_name]
            buf *= self.mom_value
            alpha = 1.0 - self.dam_value
            singa.Axpy(alpha.data, param_grad.data, buf.data)

            if self.nesterov:
                singa.Axpy(self.mom_value.data, buf.data, param_grad.data)
            else:
                param_grad = buf

        minus_lr = 0.0 - self.lr_value
        singa.Axpy(minus_lr.data, param_grad.data, param_value.data)
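Examples 2 to 4 and 6 express the same arithmetic through singa.Axpy which, judging from how it is called here, follows BLAS axpy semantics: it accumulates into its last argument, y := alpha * x + y. A minimal NumPy analogue (the helper name axpy is illustrative):

import numpy as np

def axpy(alpha, x, y):
    # BLAS-style axpy: y := alpha * x + y, accumulating into y in place
    y += alpha * x

grad = np.array([1.0, 2.0])
param = np.array([10.0, 20.0])
axpy(0.1, param, grad)   # weight decay term: grad := grad + 0.1 * param
axpy(-0.5, grad, param)  # gradient step:     param := param - 0.5 * grad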
Example No. 3
File: opt.py Project: zxr8192/singa
    def update(self, param, grad):
        """Performs a single optimization step.

        Args:
                param(Tensor): param values to be updated in place
                grad(Tensor): param gradients; the values may be updated
                        in this function and should not be used afterwards
        """
        assert param.shape == grad.shape, ("shape mismatch", param.shape,
                                           grad.shape)
        group = self.default_config
        if param in self.param2config:
            group = self.param2config[param]
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        if weight_decay != 0:
            singa.Axpy(weight_decay, param.data, grad.data)
        if momentum != 0:
            if param not in self.param2state:
                self.param2state[param] = {}
            param_state = self.param2state[param]
            if 'momentum_buffer' not in param_state:
                flag = param.device.graph_enabled()
                param.device.EnableGraph(False)
                buf = param_state['momentum_buffer'] = tensor.zeros_like(param)
                param.device.EnableGraph(flag)

                buf *= momentum
                singa.Axpy(1.0, grad.data, buf.data)
            else:
                buf = param_state['momentum_buffer']
                buf *= momentum
                singa.Axpy(1.0 - dampening, grad.data, buf.data)
            if nesterov:
                singa.Axpy(momentum, buf.data, grad.data)
            else:
                grad = buf
        singa.Axpy(-group['lr'], grad.data, param.data)
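In the update-style examples the optimizer is invoked once per (param, grad) pair after the backward pass, with settings taken from param2config when the parameter has its own entry and from default_config otherwise. A hypothetical driving loop (opt and params_and_grads are placeholders, not names from this code):

def train_step(opt, params_and_grads):
    # params_and_grads: any iterable of (param, grad) tensor pairs
    for param, grad in params_and_grads:
        # each parameter is updated in place; grad may be modified and
        # should not be reused after this call (see the docstrings above)
        opt.update(param, grad)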
Example No. 4
    def apply(self, param_name, param_value, param_grad):
        """Performs a single optimization step.

        Args:
                param_name(String): the name of the param
                param_value(Tensor): param values to be updated in place
                param_grad(Tensor): param gradients; the values may be updated
                        in this function and should not be used afterwards
        """
        assert param_value.shape == param_grad.shape, ("shape mismatch",
                                                       param_value.shape,
                                                       param_grad.shape)
        self.device_check(param_value, self.step_counter, self.lr_value,
                          self.rho_value, self.epsilon_value, self.decay_value)

        # if self.decay_value != 0:
        if self.weight_decay.init_value != 0:
            singa.Axpy(self.decay_value.data, param_value.data,
                       param_grad.data)

        if param_name not in self.running_average:
            flag = param_value.device.graph_enabled()
            param_value.device.EnableGraph(False)
            self.running_average[param_name] = tensor.zeros_like(param_value)
            param_value.device.EnableGraph(flag)

        # running_average = running_average * rho + param_grad * param_grad * (1 - rho)
        # param_value = param_value - lr * param_grad / sqrt(running_average + epsilon)

        self.running_average[param_name] *= self.rho_value

        tmp1 = singa.Square(param_grad.data)
        tmp2 = 1.0 - self.rho_value
        singa.Axpy(tmp2.data, tmp1, self.running_average[param_name].data)

        minus_lr = 0.0 - self.lr_value
        tmp3 = self.running_average[param_name] + self.epsilon_value
        tmp3 = singa.Sqrt(tmp3.data)
        tmp3 = singa.__div__(param_grad.data, tmp3)

        singa.Axpy(minus_lr.data, tmp3, param_value.data)
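Example No. 4 is an RMSProp step; the two commented formulas are the whole algorithm. Written out with NumPy for reference (the function name rmsprop_step and its defaults are illustrative):

import numpy as np

def rmsprop_step(param, grad, running_average, lr=0.001, rho=0.9,
                 epsilon=1e-8, weight_decay=0.0):
    if weight_decay != 0:
        grad = grad + weight_decay * param
    # running_average := rho * running_average + (1 - rho) * grad * grad
    running_average *= rho
    running_average += (1.0 - rho) * grad * grad
    # param := param - lr * grad / sqrt(running_average + epsilon)
    param -= lr * grad / np.sqrt(running_average + epsilon)

# illustrative usage
w = np.ones(3, dtype=np.float32)
g = np.array([0.1, -0.2, 0.3], dtype=np.float32)
avg = np.zeros_like(w)
rmsprop_step(w, g, avg)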
Example No. 5
    def update(self, param, grad):
        """Performs a single optimization step.

        Arguments:
                param(Tensor): param values to be updated in place
                grad(Tensor): param gradients; the values may be updated
                        in this function and should not be used afterwards
        """
        assert param.shape == grad.shape, ("shape mismatch", param.shape, grad.shape)
        group = self.default_config
        if param in self.param2config:
            group = self.param2config[param]
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        if weight_decay != 0:
            grad += param * weight_decay
        if momentum != 0:
            if param not in self.param2state:
                self.param2state[param] = {}
            param_state = self.param2state[param]
            if 'momentum_buffer' not in param_state:
                buf = param_state[
                    'momentum_buffer'] = tensor.zeros_like(param)
                buf *= momentum
                buf += grad
            else:
                buf = param_state['momentum_buffer']
                buf *= momentum
                buf += (1 - dampening) * grad
            if nesterov:
                grad += momentum * buf
            else:
                grad = buf
        param -= grad * group['lr']
Example No. 6
    def apply(self, param_name, param_value, param_grad):
        """Performs a single optimization step.

        Args:
                param_name(String): the name of the param
                param_value(Tensor): param values to be updated in place
                param_grad(Tensor): param gradients; the values may be updated
                        in this function and should not be used afterwards
        """
        assert param_value.shape == param_grad.shape, ("shape mismatch",
                                                       param_value.shape,
                                                       param_grad.shape)
        self.device_check(param_value, self.step_counter, self.lr_value,
                          self.beta_1_value, self.beta_2_value,
                          self.epsilon_value, self.decay_value)

        # if self.decay_value != 0:
        if self.weight_decay.init_value != 0:
            singa.Axpy(self.decay_value.data, param_value.data,
                       param_grad.data)

        if param_name not in self.m:
            flag = param_value.device.graph_enabled()
            param_value.device.EnableGraph(False)
            self.m[param_name] = tensor.zeros_like(param_value)
            self.v[param_name] = tensor.zeros_like(param_value)
            param_value.device.EnableGraph(flag)

        # overall steps
        # m := beta_1 * m + (1 - beta_1) * grad
        # v := beta_2 * v + (1 - beta_2) * grad * grad
        # m_norm = m / (1 - beta_1 ^ step)
        # v_norm = v / (1 - beta_2 ^ step)
        # param := param - (lr * m_norm) / ( sqrt(v_norm) + epsilon) )

        step = self.step_counter + 1.0

        # m := beta_1 * m + (1 - beta_1) * grad
        tmp = 1.0 - self.beta_1_value
        self.m[param_name] *= self.beta_1_value
        singa.Axpy(tmp.data, param_grad.data, self.m[param_name].data)

        # v := beta_2 * v + (1 - beta_2) * grad * grad
        tmp = 1.0 - self.beta_2_value
        self.v[param_name] *= self.beta_2_value
        singa.Axpy(tmp.data, singa.Square(param_grad.data),
                   self.v[param_name].data)

        # m_norm = m / (1 - beta_1 ^ step)
        tmp = tensor.pow(self.beta_1_value, step)
        tmp = 1.0 - tmp
        m_norm = self.m[param_name] / tmp

        # v_norm = v / (1 - beta_2 ^ step)
        tmp = tensor.pow(self.beta_2_value, step)
        tmp = 1.0 - tmp
        v_norm = self.v[param_name] / tmp

        # param := param - (lr * m_norm) / ( sqrt(v_norm) + epsilon) )
        a = tensor.sqrt(v_norm) + self.epsilon_value
        tmp = m_norm / a

        minus_lr = 0.0 - self.lr_value
        singa.Axpy(minus_lr.data, tmp.data, param_value.data)
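Example No. 6 is an Adam step following the five commented formulas, including the bias correction by step count. A NumPy rendering of those formulas (the function name adam_step and its defaults are illustrative):

import numpy as np

def adam_step(param, grad, m, v, step, lr=0.001, beta_1=0.9, beta_2=0.999,
              epsilon=1e-8, weight_decay=0.0):
    if weight_decay != 0:
        grad = grad + weight_decay * param
    # m := beta_1 * m + (1 - beta_1) * grad
    m *= beta_1
    m += (1.0 - beta_1) * grad
    # v := beta_2 * v + (1 - beta_2) * grad * grad
    v *= beta_2
    v += (1.0 - beta_2) * grad * grad
    # bias-corrected first and second moments
    m_norm = m / (1.0 - beta_1 ** step)
    v_norm = v / (1.0 - beta_2 ** step)
    # param := param - lr * m_norm / (sqrt(v_norm) + epsilon)
    param -= lr * m_norm / (np.sqrt(v_norm) + epsilon)

# illustrative usage; step starts at 1, matching step_counter + 1 above
w = np.ones(3, dtype=np.float32)
g = np.array([0.1, -0.2, 0.3], dtype=np.float32)
m = np.zeros_like(w)
v = np.zeros_like(w)
adam_step(w, g, m, v, step=1)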