def _resource_apply_sparse(self, grad, var, indices):
    """Run the fused sparse RMSProp kernel on the rows of `var` in `indices`.

    The four hyperparameter tensors (`_learning_rate_tensor`, `_decay_tensor`,
    `_momentum_tensor`, `_epsilon_tensor`) are cast to `grad.dtype` before
    being handed to the kernel. When `self._centered` is set, the "mg" slot
    is also passed and the centered kernel variant is used.

    Args:
      grad: gradient values for the selected rows.
      var: the resource variable being updated.
      indices: row indices of `var` that `grad` corresponds to.

    Returns:
      The op returned by the selected `training_ops` kernel.
    """
    rms_slot = self.get_slot(var, "rms")
    momentum_slot = self.get_slot(var, "momentum")
    is_centered = self._centered
    mg_slot = self.get_slot(var, "mg") if is_centered else None

    # Hyperparameters must match the gradient's dtype for the kernel call.
    target_dtype = grad.dtype
    lr, decay, momentum, epsilon = [
        math_ops.cast(hyper_tensor, target_dtype)
        for hyper_tensor in (self._learning_rate_tensor,
                             self._decay_tensor,
                             self._momentum_tensor,
                             self._epsilon_tensor)
    ]

    if is_centered:
        # The centered kernel takes the "mg" slot ahead of "rms".
        return training_ops.resource_sparse_apply_centered_rms_prop(
            var.handle, mg_slot.handle, rms_slot.handle, momentum_slot.handle,
            lr, decay, momentum, epsilon, grad, indices,
            use_locking=self._use_locking)
    return training_ops.resource_sparse_apply_rms_prop(
        var.handle, rms_slot.handle, momentum_slot.handle,
        lr, decay, momentum, epsilon, grad, indices,
        use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices, state):
    """Apply a sparse RMSProp update to the rows of `var` selected by `indices`.

    All slots ("rms", "momentum" and, when centered, "mg") and all
    hyperparameters ("learning_rate", "rho", "momentum", "epsilon") are read
    from `state`. The hyperparameters are fetched in the variable's base
    dtype.

    Args:
      grad: gradient values for the selected rows.
      var: the resource variable being updated.
      indices: row indices of `var` that `grad` corresponds to.
      state: per-variable optimizer state providing `get_slot` and
        `get_hyper`.

    Returns:
      The op returned by the selected `training_ops` sparse RMSProp kernel.
    """
    dtype = var.dtype.base_dtype
    rms = state.get_slot(var, "rms")
    mom = state.get_slot(var, "momentum")
    learning_rate = state.get_hyper("learning_rate", dtype)
    rho = state.get_hyper("rho", dtype)
    momentum = state.get_hyper("momentum", dtype)
    # epsilon was previously hard-coded to 0. It is the numerical-stability
    # term added to the RMSProp denominator, so 0 risks dividing by zero
    # (NaN/inf) on rows whose accumulator is still zero.
    # NOTE(review): relies on the owning optimizer registering an "epsilon"
    # hyperparameter alongside "rho"/"momentum" -- confirm.
    epsilon = state.get_hyper("epsilon", dtype)
    if self._centered:
        # Read "mg" through `state`, like the other slots (was self.get_slot).
        mg = state.get_slot(var, "mg")
        return training_ops.resource_sparse_apply_centered_rms_prop(
            var.handle, mg.handle, rms.handle, mom.handle,
            learning_rate, rho, momentum, epsilon, grad, indices,
            use_locking=self._use_locking)
    return training_ops.resource_sparse_apply_rms_prop(
        var.handle, rms.handle, mom.handle,
        learning_rate, rho, momentum, epsilon, grad, indices,
        use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices):
    """Run the fused sparse RMSprop kernel on the rows of `var` in `indices`.

    Sources of the kernel inputs:
    - The learning rate comes from `self._decayed_lr`.
    - "rho", "momentum" and "epsilon" come from `self._get_hyper`.
    - All values are fetched in the variable's base dtype.
    - When `self.centered` is set, the "mg" slot is passed too and the
      centered kernel variant is used.

    Args:
      grad: gradient values for the selected rows.
      var: the resource variable being updated.
      indices: row indices of `var` that `grad` corresponds to.

    Returns:
      The op returned by the selected `training_ops` kernel.
    """
    base_dtype = var.dtype.base_dtype
    decayed_lr = self._decayed_lr(base_dtype)
    rms_slot = self.get_slot(var, "rms")
    momentum_slot = self.get_slot(var, "momentum")
    hyper_values = [self._get_hyper(hyper_name, base_dtype)
                    for hyper_name in ("rho", "momentum", "epsilon")]

    slot_handles = [rms_slot.handle, momentum_slot.handle]
    if self.centered:
        # The centered kernel takes the "mg" slot ahead of "rms".
        slot_handles.insert(0, self.get_slot(var, "mg").handle)
        kernel = training_ops.resource_sparse_apply_centered_rms_prop
    else:
        kernel = training_ops.resource_sparse_apply_rms_prop

    kernel_args = ([var.handle] + slot_handles + [decayed_lr] + hyper_values
                   + [grad, indices])
    return kernel(*kernel_args, use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices):
    """Run the fused sparse RMSprop kernel on the rows of `var` in `indices`.

    The learning rate comes from `self._decayed_lr`, and "rho", "momentum"
    and "epsilon" come from `self._get_hyper`, all in the variable's base
    dtype. The centered kernel (which also takes the "mg" slot) is used
    when `self._centered` is set.

    Args:
      grad: gradient values for the selected rows.
      var: the resource variable being updated.
      indices: row indices of `var` that `grad` corresponds to.

    Returns:
      The op returned by the selected `training_ops` kernel.
    """
    dtype = var.dtype.base_dtype
    lr = self._decayed_lr(dtype)
    rms_var = self.get_slot(var, "rms")
    mom_var = self.get_slot(var, "momentum")
    rho_t, momentum_t, epsilon_t = (
        self._get_hyper(key, dtype) for key in ("rho", "momentum", "epsilon"))
    locking = self._use_locking

    if not self._centered:
        return training_ops.resource_sparse_apply_rms_prop(
            var.handle, rms_var.handle, mom_var.handle,
            lr, rho_t, momentum_t, epsilon_t, grad, indices,
            use_locking=locking)

    mg_var = self.get_slot(var, "mg")
    return training_ops.resource_sparse_apply_centered_rms_prop(
        var.handle, mg_var.handle, rms_var.handle, mom_var.handle,
        lr, rho_t, momentum_t, epsilon_t, grad, indices,
        use_locking=locking)
def _resource_apply_sparse(self, grad, var, indices):
    """Run the fused sparse RMSProp kernel on the rows of `var` in `indices`.

    The hyperparameters "learning_rate", "rho", "momentum" and "epsilon" are
    read with `self._get_hyper` and cast to the gradient's base dtype. When
    `self._centered` is set, the "mg" slot is also passed and the centered
    kernel variant is used.

    Args:
      grad: gradient values for the selected rows.
      var: the resource variable being updated.
      indices: row indices of `var` that `grad` corresponds to.

    Returns:
      The op returned by the selected `training_ops` kernel.
    """
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    target_dtype = grad.dtype.base_dtype
    hyper_args = [
        math_ops.cast(self._get_hyper(hyper_name), target_dtype)
        for hyper_name in ("learning_rate", "rho", "momentum", "epsilon")
    ]

    if self._centered:
        mg = self.get_slot(var, "mg")
        kernel = training_ops.resource_sparse_apply_centered_rms_prop
        # The centered kernel takes the "mg" slot ahead of "rms".
        leading_args = [var.handle, mg.handle, rms.handle, mom.handle]
    else:
        kernel = training_ops.resource_sparse_apply_rms_prop
        leading_args = [var.handle, rms.handle, mom.handle]

    return kernel(*(leading_args + hyper_args + [grad, indices]),
                  use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Apply a sparse RMSprop update to the rows of `var` selected by `indices`.

    Coefficients come from `apply_state`, keyed by the variable's
    (device, base dtype). If `apply_state` is None or has no entry for that
    key, `self._fallback_apply_state` supplies them instead.

    When `self._momentum` is truthy, the fused `training_ops` sparse RMSprop
    kernels do the update; `self.centered` selects the centered kernel.

    Otherwise the update is built from individual ops:
      rms  <- rms * rho  (assigned to the WHOLE slot), then
              grad * grad * one_minus_rho scatter-added into `indices` rows
      mg   <- same scheme with grad instead of grad * grad (centered only)
      var[indices] += neg_lr_t * grad / (sqrt(denom) + epsilon)
    Here denom is the gathered rms rows. In centered mode the squared
    gathered mg rows are subtracted from it.

    Args:
      grad: gradient values for the rows in `indices`.
      var: the resource variable being updated.
      indices: row indices of `var` that `grad` corresponds to.
      apply_state: optional dict mapping (device, dtype) to a coefficient
        dict. Its entries are "lr_t", "rho", "momentum", "epsilon",
        "one_minus_rho" and "neg_lr_t".

    Returns:
      The fused kernel's op in the momentum path. Otherwise a
      `control_flow_ops.group` of the variable and slot updates.
    """
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
    if self._momentum:
        # Momentum path: delegate the whole update to a fused kernel.
        mom = self.get_slot(var, "momentum")
        if self.centered:
            mg = self.get_slot(var, "mg")
            return training_ops.resource_sparse_apply_centered_rms_prop(
                var.handle,
                mg.handle,
                rms.handle,
                mom.handle,
                coefficients["lr_t"],
                coefficients["rho"],
                coefficients["momentum"],
                coefficients["epsilon"],
                grad,
                indices,
                use_locking=self._use_locking)
        else:
            return training_ops.resource_sparse_apply_rms_prop(
                var.handle,
                rms.handle,
                mom.handle,
                coefficients["lr_t"],
                coefficients["rho"],
                coefficients["momentum"],
                coefficients["epsilon"],
                grad,
                indices,
                use_locking=self._use_locking)
    else:
        # No-momentum path: update rms (and mg) by hand, then the variable.
        rms_scaled_g_values = (grad * grad) * coefficients["one_minus_rho"]
        # Decays every row of the slot, not only the rows in `indices`.
        rms_t = state_ops.assign(rms, rms * coefficients["rho"],
                                 use_locking=self._use_locking)
        # Scatter-add only after the decay assign has run.
        with ops.control_dependencies([rms_t]):
            rms_t = self._resource_scatter_add(rms, indices, rms_scaled_g_values)
            rms_slice = array_ops.gather(rms_t, indices)
        denom_slice = rms_slice
        if self.centered:
            mg = self.get_slot(var, "mg")
            mg_scaled_g_values = grad * coefficients["one_minus_rho"]
            # As with rms, the decay applies to the whole mg slot.
            mg_t = state_ops.assign(mg, mg * coefficients["rho"],
                                    use_locking=self._use_locking)
            with ops.control_dependencies([mg_t]):
                mg_t = self._resource_scatter_add(mg, indices, mg_scaled_g_values)
                mg_slice = array_ops.gather(mg_t, indices)
                # Centered denominator: rms - mg**2 on the selected rows.
                denom_slice = rms_slice - math_ops.square(mg_slice)
        var_update = self._resource_scatter_add(
            var, indices, coefficients["neg_lr_t"] * grad / (
                math_ops.sqrt(denom_slice) + coefficients["epsilon"]))
        # Group the variable update with the slot updates so one op runs all.
        if self.centered:
            return control_flow_ops.group(*[var_update, rms_t, mg_t])
        return control_flow_ops.group(*[var_update, rms_t])
def _resource_apply_sparse(self, grad, var, indices):
    """Apply a sparse RMSprop update to the rows of `var` selected by `indices`.

    Inputs are fetched in the variable's base dtype:
    - the learning rate from `self._decayed_lr`;
    - "rho", "momentum" and "epsilon" from `self._get_hyper`.

    When `self._momentum` is truthy, the fused `training_ops` sparse RMSprop
    kernels do the update; `self.centered` selects the centered kernel.

    Otherwise the update is built from individual ops:
      rms  <- rms * rho  (assigned to the WHOLE slot), then
              grad * grad * (1 - rho) scatter-added into `indices` rows
      mg   <- same scheme with grad instead of grad * grad (centered only)
      var[indices] += -lr_t * grad / (sqrt(denom) + epsilon)
    Here denom is the gathered rms rows. In centered mode the squared
    gathered mg rows are subtracted from it.

    Args:
      grad: gradient values for the rows in `indices`.
      var: the resource variable being updated.
      indices: row indices of `var` that `grad` corresponds to.

    Returns:
      The fused kernel's op in the momentum path. Otherwise a
      `control_flow_ops.group` of the variable and slot updates.
    """
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    rms = self.get_slot(var, "rms")
    rho = self._get_hyper("rho", var_dtype)
    # Fetched unconditionally, but only consumed by the fused-kernel path.
    momentum = self._get_hyper("momentum", var_dtype)
    epsilon = self._get_hyper("epsilon", var_dtype)
    if self._momentum:
        # Momentum path: delegate the whole update to a fused kernel.
        mom = self.get_slot(var, "momentum")
        if self.centered:
            mg = self.get_slot(var, "mg")
            return training_ops.resource_sparse_apply_centered_rms_prop(
                var.handle,
                mg.handle,
                rms.handle,
                mom.handle,
                lr_t,
                rho,
                momentum,
                epsilon,
                grad,
                indices,
                use_locking=self._use_locking)
        else:
            return training_ops.resource_sparse_apply_rms_prop(
                var.handle,
                rms.handle,
                mom.handle,
                lr_t,
                rho,
                momentum,
                epsilon,
                grad,
                indices,
                use_locking=self._use_locking)
    else:
        # No-momentum path: update rms (and mg) by hand, then the variable.
        rms_scaled_g_values = (grad * grad) * (1. - rho)
        # Decays every row of the slot, not only the rows in `indices`.
        rms_t = state_ops.assign(rms, rms * rho, use_locking=self._use_locking)
        # Scatter-add only after the decay assign has run.
        with ops.control_dependencies([rms_t]):
            rms_t = self._resource_scatter_add(rms, indices, rms_scaled_g_values)
            rms_slice = array_ops.gather(rms_t, indices)
        denom_slice = rms_slice
        if self.centered:
            mg = self.get_slot(var, "mg")
            mg_scaled_g_values = grad * (1. - rho)
            # As with rms, the decay applies to the whole mg slot.
            mg_t = state_ops.assign(mg, mg * rho, use_locking=self._use_locking)
            with ops.control_dependencies([mg_t]):
                mg_t = self._resource_scatter_add(mg, indices, mg_scaled_g_values)
                mg_slice = array_ops.gather(mg_t, indices)
                # Centered denominator: rms - mg**2 on the selected rows.
                denom_slice = rms_slice - math_ops.square(mg_slice)
        var_update = self._resource_scatter_add(
            var, indices, -lr_t * grad / (math_ops.sqrt(denom_slice) + epsilon))
        # Group the variable update with the slot updates so one op runs all.
        if self.centered:
            return control_flow_ops.group(*[var_update, rms_t, mg_t])
        return control_flow_ops.group(*[var_update, rms_t])