def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Apply a sparse Adagrad update to `var` via the fused v2 kernel.

    NOTE(review): a second definition of `_resource_apply_sparse` appears
    later in this file and will shadow this one at class-creation time —
    confirm which version is intended to survive.
    """
    device, dtype = var.device, var.dtype.base_dtype
    # Prefer the precomputed per-(device, dtype) coefficients; fall back to
    # recomputing them when the cache has no entry.
    coeffs = ((apply_state or {}).get((device, dtype))
              or self._fallback_apply_state(device, dtype))
    accumulator = self.get_slot(var, 'accumulator')
    return training_ops.resource_sparse_apply_adagrad_v2(
        var.handle,
        accumulator.handle,
        coeffs['lr_t'],
        coeffs['epsilon'],
        grad,
        indices,
        use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Apply a sparse Adagrad update: accumulate squared grads, scale the step.

    Uses the fused `resource_sparse_apply_adagrad_v2` kernel when the runtime
    is new enough; otherwise composes the update from scatter-add primitives.
    """
    device, dtype = var.device, var.dtype.base_dtype
    # Prefer the precomputed per-(device, dtype) coefficients; fall back to
    # recomputing them when the cache has no entry.
    coeffs = ((apply_state or {}).get((device, dtype))
              or self._fallback_apply_state(device, dtype))
    accumulator = self.get_slot(var, 'accumulator')

    if compat.forward_compatible(2019, 8, 20):
        # Newer runtimes ship a single fused kernel for the whole update.
        return training_ops.resource_sparse_apply_adagrad_v2(
            var.handle,
            accumulator.handle,
            coeffs['lr_t'],
            coeffs['epsilon'],
            grad,
            indices,
            use_locking=self._use_locking)

    # Fallback path: accumulator[indices] += grad**2 must finish before the
    # variable update reads the refreshed accumulator slices, hence the
    # explicit control dependency.
    grow_accumulator = resource_variable_ops.resource_scatter_add(
        accumulator.handle, indices, math_ops.square(grad))
    with ops.control_dependencies([grow_accumulator]):
        acc_slice = accumulator.sparse_read(indices)
        # var[indices] += -lr * grad / (sqrt(acc) + epsilon)
        return resource_variable_ops.resource_scatter_add(
            var.handle,
            indices,
            coeffs['neg_lr_t'] * grad /
            (math_ops.sqrt(acc_slice) + coeffs['epsilon']))