Esempio n. 1
0
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Run the fused sparse Adagrad op on `var` at `indices`.

        Args:
            grad: gradient values handed unchanged to the op (presumably one
                slice per entry of `indices` -- confirm against the caller).
            var: resource variable to update; its `handle` is passed to the op.
            indices: index tensor handed unchanged to the op.
            apply_state: optional mapping keyed by `(device, base dtype)`
                holding precomputed coefficients. A missing or empty entry
                falls back to `self._fallback_apply_state`.

        Returns:
            The result of `training_ops.resource_sparse_apply_adagrad_v2`.
        """
        device = var.device
        dtype = var.dtype.base_dtype

        # Truthiness test (not `is None`): an empty coefficient entry is
        # treated the same as a missing one and triggers the fallback.
        coeffs = (apply_state or {}).get((device, dtype))
        if not coeffs:
            coeffs = self._fallback_apply_state(device, dtype)

        accumulator = self.get_slot(var, 'accumulator')
        return training_ops.resource_sparse_apply_adagrad_v2(
            var.handle,
            accumulator.handle,
            coeffs['lr_t'],
            coeffs['epsilon'],
            grad,
            indices,
            use_locking=self._use_locking)
Esempio n. 2
0
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply a sparse Adagrad update to `var` at `indices`.

        When `compat.forward_compatible(2019, 8, 20)` is true the fused
        `resource_sparse_apply_adagrad_v2` op is used. Otherwise the update
        is composed from scatter-add ops: the squared gradient is added to
        the accumulator, the updated accumulator slice is read back, and
        `neg_lr_t * grad / (sqrt(acc) + epsilon)` is added to `var`.

        Args:
            grad: gradient values (presumably one slice per entry of
                `indices` -- confirm against the caller).
            var: resource variable to update.
            indices: index tensor used for both scatter-adds and for the
                accumulator read.
            apply_state: optional mapping keyed by `(device, base dtype)`
                holding precomputed coefficients. A missing or empty entry
                falls back to `self._fallback_apply_state`.

        Returns:
            The fused op's result, or the scatter-add op that updates `var`.
        """
        device = var.device
        dtype = var.dtype.base_dtype

        # Truthiness test (not `is None`): an empty coefficient entry is
        # treated the same as a missing one and triggers the fallback.
        coeffs = (apply_state or {}).get((device, dtype))
        if not coeffs:
            coeffs = self._fallback_apply_state(device, dtype)

        accumulator = self.get_slot(var, 'accumulator')

        if compat.forward_compatible(2019, 8, 20):
            return training_ops.resource_sparse_apply_adagrad_v2(
                var.handle,
                accumulator.handle,
                coeffs['lr_t'],
                coeffs['epsilon'],
                grad,
                indices,
                use_locking=self._use_locking)

        # Composed fallback. Step 1: accumulate grad^2 into the slot.
        accumulate_op = resource_variable_ops.resource_scatter_add(
            accumulator.handle, indices, math_ops.square(grad))

        # Step 2: read the slice only after the accumulation has run, so the
        # step below sees the updated accumulator values.
        with ops.control_dependencies([accumulate_op]):
            updated_acc = accumulator.sparse_read(indices)

        # Step 3: scaled step, added in place to the selected entries of var.
        scaled_grad = coeffs['neg_lr_t'] * grad
        denominator = math_ops.sqrt(updated_acc) + coeffs['epsilon']
        return resource_variable_ops.resource_scatter_add(
            var.handle, indices, scaled_grad / denominator)