def update(self, param, grad):
    """Performs a single optimization step.

    Arguments:
        param(Tensor): param values to be updated in-place
        grad(Tensor): param gradients; the values may be modified
            in this function; do not reuse it afterwards
    """
    assert param.shape == grad.shape, ("shape mismatch", param.shape,
                                       grad.shape)
    group = self.default_config
    if param in self.param2config:
        group = self.param2config[param]
    weight_decay = group['weight_decay']
    momentum = group['momentum']
    dampening = group['dampening']
    nesterov = group['nesterov']

    if weight_decay != 0:
        grad += param * weight_decay
    if momentum != 0:
        if param not in self.param2state:
            self.param2state[param] = {}
        param_state = self.param2state[param]
        if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = tensor.zeros_like(param)
            buf *= momentum  # buf is zero here, so this is a no-op
            buf += grad      # first step: buffer starts as the raw gradient
        else:
            buf = param_state['momentum_buffer']
            buf *= momentum
            buf += (1 - dampening) * grad
        if nesterov:
            grad += momentum * buf
        else:
            grad = buf
    param -= grad * group['lr']
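# For reference, a minimal NumPy sketch of the same momentum/nesterov rule,
# stripped of the SINGA Tensor API (the standalone function and its names are
# illustrative assumptions, not part of the library):
import numpy as np

def sgd_momentum_step(param, grad, buf, lr, momentum=0.9, dampening=0.0,
                      weight_decay=0.0, nesterov=False):
    """One SGD step; buf is the momentum buffer (None before the first step)."""
    if weight_decay != 0:
        grad = grad + weight_decay * param  # L2 penalty folded into the gradient
    if momentum != 0:
        if buf is None:
            buf = grad.copy()  # first step: buffer starts as the raw gradient
        else:
            buf = momentum * buf + (1 - dampening) * grad
        grad = grad + momentum * buf if nesterov else buf
    return param - lr * grad, buf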
def apply(self, param_name, param_value, param_grad):
    """Performs a single optimization step.

    Args:
        param_name(String): the name of the param
        param_value(Tensor): param values to be updated in-place
        param_grad(Tensor): param gradients; the values may be modified
            in this function; do not reuse it afterwards
    """
    assert param_value.shape == param_grad.shape, ("shape mismatch",
                                                   param_value.shape,
                                                   param_grad.shape)
    self.device_check(param_value, self.step_counter, self.lr_value,
                      self.mom_value, self.dam_value, self.decay_value)

    # derive dtype from input
    assert param_value.dtype == self.dtype

    # TODO add branch operator
    # if self.decay_value != 0:
    if self.weight_decay.init_value != 0:
        singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)

    if self.momentum.init_value != 0:
        if param_name not in self.moments:
            flag = param_value.device.graph_enabled()
            param_value.device.EnableGraph(False)
            self.moments[param_name] = tensor.zeros_like(param_value)
            param_value.device.EnableGraph(flag)

        buf = self.moments[param_name]
        buf *= self.mom_value
        alpha = 1.0 - self.dam_value
        singa.Axpy(alpha.data, param_grad.data, buf.data)

        if self.nesterov:
            singa.Axpy(self.mom_value.data, buf.data, param_grad.data)
        else:
            param_grad = buf

    minus_lr = 0.0 - self.lr_value
    singa.Axpy(minus_lr.data, param_grad.data, param_value.data)
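# The apply() above leans on singa.Axpy, which follows the BLAS axpy
# convention: y := alpha * x + y, updating y in place. A NumPy stand-in
# (illustrative only) makes the calls above easy to read:
import numpy as np

def axpy(alpha, x, y):
    """In-place y := alpha * x + y, mirroring BLAS/SINGA axpy semantics."""
    y += alpha * x
    return y

# weight decay:  grad := grad + decay * param
# momentum:      buf  := buf + (1 - dampening) * grad   (after buf *= momentum)
# update:        param := param - lr * grad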
def update(self, param, grad):
    """Performs a single optimization step.

    Args:
        param(Tensor): param values to be updated in-place
        grad(Tensor): param gradients; the values may be modified
            in this function; do not reuse it afterwards
    """
    assert param.shape == grad.shape, ("shape mismatch", param.shape,
                                       grad.shape)
    group = self.default_config
    if param in self.param2config:
        group = self.param2config[param]
    weight_decay = group['weight_decay']
    momentum = group['momentum']
    dampening = group['dampening']
    nesterov = group['nesterov']

    if weight_decay != 0:
        singa.Axpy(weight_decay, param.data, grad.data)
    if momentum != 0:
        if param not in self.param2state:
            self.param2state[param] = {}
        param_state = self.param2state[param]
        if 'momentum_buffer' not in param_state:
            flag = param.device.graph_enabled()
            param.device.EnableGraph(False)
            buf = param_state['momentum_buffer'] = tensor.zeros_like(param)
            param.device.EnableGraph(flag)
            buf *= momentum  # buf is zero here, so this is a no-op
            singa.Axpy(1.0, grad.data, buf.data)
        else:
            buf = param_state['momentum_buffer']
            buf *= momentum
            singa.Axpy(1.0 - dampening, grad.data, buf.data)
        if nesterov:
            singa.Axpy(momentum, buf.data, grad.data)
        else:
            grad = buf
    singa.Axpy(-group['lr'], grad.data, param.data)
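# In equation form, the Axpy calls in this update() compose into the usual
# momentum rule (on the first step b_1 = g_1, without dampening, because the
# fresh zero buffer makes buf *= momentum a no-op):
#
#   b_t  = momentum * b_{t-1} + (1 - dampening) * g_t
#   g_t' = g_t + momentum * b_t   if nesterov, else   g_t' = b_t
#   p_t  = p_{t-1} - lr * g_t'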
def apply(self, param_name, param_value, param_grad):
    """Performs a single optimization step.

    Args:
        param_name(String): the name of the param
        param_value(Tensor): param values to be updated in-place
        param_grad(Tensor): param gradients; the values may be modified
            in this function; do not reuse it afterwards
    """
    assert param_value.shape == param_grad.shape, ("shape mismatch",
                                                   param_value.shape,
                                                   param_grad.shape)
    self.device_check(param_value, self.step_counter, self.lr_value,
                      self.rho_value, self.epsilon_value, self.decay_value)

    # if self.decay_value != 0:
    if self.weight_decay.init_value != 0:
        singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)

    if param_name not in self.running_average:
        flag = param_value.device.graph_enabled()
        param_value.device.EnableGraph(False)
        self.running_average[param_name] = tensor.zeros_like(param_value)
        param_value.device.EnableGraph(flag)

    # running_average = running_average * rho + param_grad * param_grad * (1 - rho)
    # param_value = param_value - lr * param_grad / sqrt(running_average + epsilon)
    self.running_average[param_name] *= self.rho_value

    tmp1 = singa.Square(param_grad.data)
    tmp2 = 1.0 - self.rho_value
    singa.Axpy(tmp2.data, tmp1, self.running_average[param_name].data)

    minus_lr = 0.0 - self.lr_value
    tmp3 = self.running_average[param_name] + self.epsilon_value
    tmp3 = singa.Sqrt(tmp3.data)
    tmp3 = singa.__div__(param_grad.data, tmp3)
    singa.Axpy(minus_lr.data, tmp3, param_value.data)
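# A minimal NumPy sketch of the RMSProp rule implemented above (the
# standalone function and its names are illustrative assumptions, not the
# SINGA API):
import numpy as np

def rmsprop_step(param, grad, running_average, lr, rho=0.9, epsilon=1e-8,
                 weight_decay=0.0):
    """One RMSProp step; running_average holds the moving average of grad**2."""
    if weight_decay != 0:
        grad = grad + weight_decay * param
    running_average = rho * running_average + (1 - rho) * grad * grad
    param = param - lr * grad / np.sqrt(running_average + epsilon)
    return param, running_average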
def apply(self, param_name, param_value, param_grad):
    """Performs a single optimization step.

    Args:
        param_name(String): the name of the param
        param_value(Tensor): param values to be updated in-place
        param_grad(Tensor): param gradients; the values may be modified
            in this function; do not reuse it afterwards
    """
    assert param_value.shape == param_grad.shape, ("shape mismatch",
                                                   param_value.shape,
                                                   param_grad.shape)
    self.device_check(param_value, self.step_counter, self.lr_value,
                      self.beta_1_value, self.beta_2_value,
                      self.epsilon_value, self.decay_value)

    # if self.decay_value != 0:
    if self.weight_decay.init_value != 0:
        singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)

    if param_name not in self.m:
        flag = param_value.device.graph_enabled()
        param_value.device.EnableGraph(False)
        self.m[param_name] = tensor.zeros_like(param_value)
        self.v[param_name] = tensor.zeros_like(param_value)
        param_value.device.EnableGraph(flag)

    # overall steps:
    # m := beta_1 * m + (1 - beta_1) * grad
    # v := beta_2 * v + (1 - beta_2) * grad * grad
    # m_norm = m / (1 - beta_1 ^ step)
    # v_norm = v / (1 - beta_2 ^ step)
    # param := param - (lr * m_norm) / (sqrt(v_norm) + epsilon)
    step = self.step_counter + 1.0

    # m := beta_1 * m + (1 - beta_1) * grad
    tmp = 1.0 - self.beta_1_value
    self.m[param_name] *= self.beta_1_value
    singa.Axpy(tmp.data, param_grad.data, self.m[param_name].data)

    # v := beta_2 * v + (1 - beta_2) * grad * grad
    tmp = 1.0 - self.beta_2_value
    self.v[param_name] *= self.beta_2_value
    singa.Axpy(tmp.data, singa.Square(param_grad.data),
               self.v[param_name].data)

    # m_norm = m / (1 - beta_1 ^ step)
    tmp = tensor.pow(self.beta_1_value, step)
    tmp = 1.0 - tmp
    m_norm = self.m[param_name] / tmp

    # v_norm = v / (1 - beta_2 ^ step)
    tmp = tensor.pow(self.beta_2_value, step)
    tmp = 1.0 - tmp
    v_norm = self.v[param_name] / tmp

    # param := param - (lr * m_norm) / (sqrt(v_norm) + epsilon)
    a = tensor.sqrt(v_norm) + self.epsilon_value
    tmp = m_norm / a
    minus_lr = 0.0 - self.lr_value
    singa.Axpy(minus_lr.data, tmp.data, param_value.data)
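# A minimal NumPy sketch of the Adam steps above (an illustrative assumption,
# not the SINGA API; step starts at 1 so the bias corrections are nonzero):
import numpy as np

def adam_step(param, grad, m, v, step, lr, beta_1=0.9, beta_2=0.999,
              epsilon=1e-8, weight_decay=0.0):
    """One Adam step; m and v are the first- and second-moment estimates."""
    if weight_decay != 0:
        grad = grad + weight_decay * param
    m = beta_1 * m + (1 - beta_1) * grad
    v = beta_2 * v + (1 - beta_2) * grad * grad
    m_norm = m / (1 - beta_1 ** step)   # bias-corrected first moment
    v_norm = v / (1 - beta_2 ** step)   # bias-corrected second moment
    param = param - lr * m_norm / (np.sqrt(v_norm) + epsilon)
    return param, m, v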