def _resource_apply_dense(self, grad, var):
    """Emit the fused AdaMax update op for a dense gradient on `var`."""
    dtype = grad.dtype.base_dtype
    # Moving averages maintained as optimizer slots.
    first_moment = self.get_slot(var, "m")
    inf_norm = self.get_slot(var, "v")
    # NOTE(review): assumes _get_beta_accumulators() returns the beta1 power
    # accumulator alone (not a tuple) — confirm against the class definition.
    beta1_power = self._get_beta_accumulators()
    return training_ops.resource_apply_ada_max(
        var.handle,
        first_moment.handle,
        inf_norm.handle,
        math_ops.cast(beta1_power, dtype),
        math_ops.cast(self._lr_t, dtype),
        math_ops.cast(self._beta1_t, dtype),
        math_ops.cast(self._beta2_t, dtype),
        math_ops.cast(self._epsilon_t, dtype),
        grad,
        use_locking=self._use_locking)
def _resource_apply_dense(self, grad, var, apply_state=None):
    """Apply a dense AdaMax step to `var` using cached per-(device, dtype) coefficients."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    # Prefer coefficients precomputed in `apply_state`; fall back to
    # rebuilding them when the entry is absent (or falsy, mirroring `or`).
    coefficients = (apply_state or {}).get((var_device, var_dtype))
    if not coefficients:
        coefficients = self._fallback_apply_state(var_device, var_dtype)
    moment = self.get_slot(var, 'm')
    norm = self.get_slot(var, 'v')
    return training_ops.resource_apply_ada_max(
        var.handle,
        moment.handle,
        norm.handle,
        coefficients['beta_1_power'],
        coefficients['lr_t'],
        coefficients['beta_1_t'],
        coefficients['beta_2_t'],
        coefficients['epsilon'],
        grad,
        use_locking=self._use_locking)
def _resource_apply_dense(self, grad, var):
    """Dense AdaMax update built from hyperparameters cast to the grad dtype."""
    dtype = grad.dtype.base_dtype

    def _hyper(name):
        # Every hyperparameter is consumed in the gradient's dtype.
        return math_ops.cast(self._get_hyper(name), dtype)

    beta_1 = _hyper('beta_1')
    beta_2 = _hyper('beta_2')
    # Step count is 1-based for the bias-correction power term.
    one_based_step = math_ops.cast(self.iterations + 1, dtype)
    beta_1_power = math_ops.pow(beta_1, one_based_step)
    slot_m = self.get_slot(var, 'm')
    slot_v = self.get_slot(var, 'v')
    return training_ops.resource_apply_ada_max(
        var.handle,
        slot_m.handle,
        slot_v.handle,
        beta_1_power,
        _hyper('learning_rate'),
        beta_1,
        beta_2,
        _hyper('epsilon'),
        grad,
        use_locking=self._use_locking)
def _resource_apply_dense(self, grad, var):
    """Dense AdaMax step driven by the decay-adjusted learning rate."""
    dtype = var.dtype.base_dtype
    lr = self._decayed_lr(dtype)  # learning rate after any schedule/decay
    beta_1 = self._get_hyper('beta_1', dtype)
    beta_2 = self._get_hyper('beta_2', dtype)
    # 1-based step feeds the beta_1 bias-correction power.
    one_based_step = math_ops.cast(self.iterations + 1, dtype)
    beta_1_power = math_ops.pow(beta_1, one_based_step)
    slot_m = self.get_slot(var, 'm')
    slot_v = self.get_slot(var, 'v')
    return training_ops.resource_apply_ada_max(
        var.handle,
        slot_m.handle,
        slot_v.handle,
        beta_1_power,
        lr,
        beta_1,
        beta_2,
        self._get_hyper('epsilon', dtype),
        grad,
        use_locking=self._use_locking)