def test_layer(self):
    """Check that ManifoldEmbedding keeps its manifold through a config round-trip.

    A layer is built on a Grassmannian manifold and called on a symbolic
    input. The test asserts that its ``embeddings`` variable is tagged with
    a manifold of the same type. It then repeats the check on a second layer
    rebuilt from the first layer's ``get_config()``.
    """
    grassmannian = Grassmannian()
    expected_type = type(grassmannian)
    with self.cached_session(use_gpu=True):
        inputs = tf.keras.Input((5, 2))

        original = ManifoldEmbedding(1000, 64, manifold=grassmannian)
        original(inputs)
        self.assertEqual(type(get_manifold(original.embeddings)), expected_type)

        # Rebuild from the serialized config and make sure the manifold survives.
        rebuilt = ManifoldEmbedding(**original.get_config())
        rebuilt(inputs)
        self.assertEqual(type(get_manifold(rebuilt.embeddings)), expected_type)
def _resource_apply_dense(self, grad, var, apply_state=None):
    """Apply one Riemannian SGD step to ``var`` from a dense gradient.

    The Euclidean gradient is converted to a Riemannian gradient on the
    variable's manifold. The update is then applied through the manifold's
    retraction.

    - With momentum enabled, the momentum slot is updated and
      vector-transported to the new point before being stored.
    - With ``nesterov``, the retraction follows the look-ahead direction.
    - When ``self.stabilize`` is set, ``_stabilize`` runs after the update.

    Args:
      grad: Euclidean gradient with respect to ``var``.
      var: Resource variable to update; its manifold comes from
        ``get_manifold``.
      apply_state: Optional dict of coefficients keyed by
        ``(device, dtype)``; ``_fallback_apply_state`` is used when absent.
    """
    device, dtype = var.device, var.dtype.base_dtype
    coeffs = (apply_state or {}).get((device, dtype)) or self._fallback_apply_state(
        device, dtype)

    manifold = get_manifold(var)
    rgrad = manifold.egrad2rgrad(var, grad)

    if not self._momentum:
        # Plain Riemannian gradient descent: retract along -lr * grad.
        var.assign(manifold.retr(var, -rgrad * coeffs["lr_t"]))
    else:
        lr_t = coeffs["lr_t"]
        momentum = self.get_slot(var, "momentum")
        new_momentum = momentum * self._momentum - rgrad * lr_t
        if self.nesterov:
            direction = new_momentum * self._momentum - rgrad * lr_t
        else:
            direction = new_momentum
        new_var = manifold.retr(var, direction)
        # Carry the momentum to the tangent space at the new point.
        momentum.assign(manifold.transp(var, new_var, new_momentum))
        var.assign(new_var)

    if self.stabilize is not None:
        self._stabilize(var)
def _stabilize(self, var):
    """Re-project ``var`` (and its momentum slot) every ``self.stabilize`` steps.

    This only takes effect when ``self.iterations`` is a multiple of
    ``self.stabilize``. In that case the variable is projected with
    ``manifold.projx``, presumably to pull it back onto the manifold after
    accumulated retraction error. When momentum is enabled, the momentum
    slot is then projected with ``manifold.proju`` at the re-projected point.
    """
    is_due = math_ops.floor_mod(self.iterations, self.stabilize) == 0
    # NOTE(review): ``is_due`` is a tensor used as a Python bool. This works
    # eagerly but may fail if this method is traced in graph mode without
    # AutoGraph conversion -- confirm how apply_gradients is invoked.
    if not is_due:
        return

    manifold = get_manifold(var)
    var.assign(manifold.projx(var))
    if self._momentum:
        momentum = self.get_slot(var, "momentum")
        momentum.assign(manifold.proju(var, momentum))
def _resource_apply_dense(self, grad, var, apply_state=None):
    """Apply one Riemannian Adam step to ``var`` from a dense gradient.

    The Euclidean gradient is converted to a Riemannian gradient.

    - The first-moment slot ``m`` is updated with that gradient.
    - The second-moment slot ``v`` is updated with the manifold inner
      product ``<g, g>`` (``keepdims=True``).
    - With ``amsgrad``, the denominator uses the running maximum stored in
      ``vhat``, while ``v`` keeps the raw estimate.
    - The variable moves by retraction and ``m`` is transported to the new
      point.
    - When ``self.stabilize`` is set, ``_stabilize`` runs afterwards.

    Args:
      grad: Euclidean gradient with respect to ``var``.
      var: Resource variable to update.
      apply_state: Optional dict of coefficients keyed by
        ``(device, dtype)``; ``_fallback_apply_state`` is used when absent.
    """
    device, dtype = var.device, var.dtype.base_dtype
    coeffs = (apply_state or {}).get((device, dtype)) or self._fallback_apply_state(
        device, dtype)

    first_moment = self.get_slot(var, "m")
    second_moment = self.get_slot(var, "v")

    manifold = get_manifold(var)
    rgrad = manifold.egrad2rgrad(var, grad)

    # Bias-corrected step size.
    step_size = (coeffs["lr_t"] * math_ops.sqrt(1 - coeffs["beta_2_power"]) /
                 (1 - coeffs["beta_1_power"]))

    # Moving averages in the "x += (target - x) * (1 - beta)" form.
    first_moment.assign_add((rgrad - first_moment) * (1 - coeffs["beta_1_t"]))
    squared_norm = manifold.inner(var, rgrad, rgrad, keepdims=True)
    second_moment.assign_add(
        (squared_norm - second_moment) * (1 - coeffs["beta_2_t"]))

    denominator_source = second_moment
    if self.amsgrad:
        vhat = self.get_slot(var, "vhat")
        vhat.assign(math_ops.maximum(vhat, second_moment))
        denominator_source = vhat

    new_var = manifold.retr(
        var,
        -(first_moment * step_size) /
        (math_ops.sqrt(denominator_source) + coeffs["epsilon"]))
    # Transport the first moment to the tangent space at the new point.
    first_moment.assign(manifold.transp(var, new_var, first_moment))
    var.assign(new_var)

    if self.stabilize is not None:
        self._stabilize(var)
def test_variable(self):
    """Check default and explicit manifold assignment on plain variables.

    - An untagged variable reports a Euclidean manifold.
    - Untagged ``1x2`` and ``2x1`` variables report the same manifold type.
    - Assigning a Grassmannian to the ``1x2`` variable raises ``ValueError``.
    - The ``2x1`` variable accepts the Grassmannian and is then tagged with it.
    """
    euclidean_type = type(Euclidean())
    grassmannian = Grassmannian()
    with self.cached_session(use_gpu=True):
        row = tf.Variable([[5, 3]])
        column = tf.Variable([[5], [3]])

        self.assertEqual(type(variable.get_manifold(row)), euclidean_type)
        self.assertEqual(
            type(variable.get_manifold(row)),
            type(variable.get_manifold(column)),
        )

        with self.assertRaises(ValueError):
            variable.assign_to_manifold(row, grassmannian)

        variable.assign_to_manifold(column, grassmannian)
        self.assertEqual(type(variable.get_manifold(column)), type(grassmannian))
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Apply one Riemannian Adam step to the rows of ``var`` at ``indices``.

    ``grad`` holds Euclidean gradient values for ``var[indices]``. The step:

    - converts those values to Riemannian gradients at the gathered rows;
    - updates the gathered rows of the ``m`` and ``v`` slots;
    - retracts the rows;
    - scatters the transported ``m``, the raw ``v`` and the new rows back.

    With ``amsgrad``, ``vhat`` receives an elementwise running max. The
    denominator uses ``vhat``, but the ``v`` slot still stores the raw
    second-moment estimate, consistent with the dense update.

    Args:
      grad: Euclidean gradient values for the selected rows.
      var: Resource variable to update.
      indices: Row indices of ``var`` that ``grad`` refers to.
      apply_state: Optional dict of coefficients keyed by
        ``(device, dtype)``; ``_fallback_apply_state`` is used when absent.
    """
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = (apply_state or {}).get(
        (var_device, var_dtype)) or self._fallback_apply_state(
            var_device, var_dtype)

    manifold = get_manifold(var)
    # The gradient rows belong to the points var[indices]. Use those rows,
    # not the full variable, as base points so that shapes match ``grad`` and
    # the retraction/transport below.
    var_values = array_ops.gather(var, indices)
    grad = manifold.egrad2rgrad(var_values, grad)

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
    m_t_values = (array_ops.gather(m, indices) * coefficients["beta_1_t"] +
                  m_scaled_g_values)

    # v_t = beta2 * v + (1 - beta2) * <g_t, g_t>
    v = self.get_slot(var, "v")
    v_scaled_g_values = (
        manifold.inner(var_values, grad, grad, keepdims=True) *
        coefficients["one_minus_beta_2_t"])
    v_t_values = (array_ops.gather(v, indices) * coefficients["beta_2_t"] +
                  v_scaled_g_values)

    # AMSGrad divides by the running maximum. The raw v_t must still be what
    # is written to the "v" slot, so keep the denominator in its own name.
    denom_values = v_t_values
    if self.amsgrad:
        vhat = self.get_slot(var, "vhat")
        vhat.scatter_max(ops.IndexedSlices(v_t_values, indices))
        denom_values = array_ops.gather(vhat, indices)

    # NOTE(review): coefficients["lr"] is presumably the bias-corrected step
    # size prepared in _prepare_local (the dense path computes the same
    # quantity explicitly) -- confirm.
    var_t_values = manifold.retr(
        var_values,
        -(m_t_values * coefficients["lr"]) /
        (math_ops.sqrt(denom_values) + coefficients["epsilon"]),
    )
    # Transport the first moment to the tangent spaces at the new points.
    m_t_transp = manifold.transp(var_values, var_t_values, m_t_values)

    m.scatter_update(ops.IndexedSlices(m_t_transp, indices))
    v.scatter_update(ops.IndexedSlices(v_t_values, indices))
    var.scatter_update(ops.IndexedSlices(var_t_values, indices))

    if self.stabilize is not None:
        self._stabilize(var)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Apply one Riemannian SGD step to the rows of ``var`` at ``indices``.

    ``grad`` holds Euclidean gradient values for ``var[indices]``. They are
    converted to Riemannian gradients at those rows, and the rows are
    retracted and written back with ``scatter_update``.

    - With momentum, the gathered momentum rows are updated, transported to
      the new points and scattered back.
    - With ``nesterov``, the retraction uses the look-ahead direction.
    - When ``self.stabilize`` is set, ``_stabilize`` runs afterwards.

    Args:
      grad: Euclidean gradient values for the selected rows.
      var: Resource variable to update.
      indices: Row indices of ``var`` that ``grad`` refers to.
      apply_state: Optional dict of coefficients keyed by
        ``(device, dtype)``; ``_fallback_apply_state`` is used when absent.
    """
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = (apply_state or {}).get(
        (var_device, var_dtype)) or self._fallback_apply_state(
            var_device, var_dtype)

    manifold = get_manifold(var)
    # The gradient rows belong to the points var[indices]. Convert them using
    # those rows (not the full variable) as base points, so that shapes match
    # ``grad`` and the base point agrees with the retraction/transport below.
    var_values = array_ops.gather(var, indices)
    grad = manifold.egrad2rgrad(var_values, grad)

    if self._momentum:
        momentum = self.get_slot(var, "momentum")
        momentum_t_values = (
            array_ops.gather(momentum, indices) * self._momentum -
            grad * coefficients["lr_t"])
        if self.nesterov:
            var_t_values = manifold.retr(
                var_values,
                momentum_t_values * self._momentum - grad * coefficients["lr_t"],
            )
        else:
            var_t_values = manifold.retr(var_values, momentum_t_values)
        # Carry the momentum to the tangent spaces at the new points.
        momentum_transp_values = manifold.transp(
            var_values, var_t_values, momentum_t_values)
        momentum.scatter_update(
            ops.IndexedSlices(momentum_transp_values, indices))
    else:
        var_t_values = manifold.retr(var_values, -grad * coefficients["lr_t"])

    var.scatter_update(ops.IndexedSlices(var_t_values, indices))

    if self.stabilize is not None:
        self._stabilize(var)
def _stabilize(self, var):
    """Re-project ``var`` and its first-moment slot every ``self.stabilize`` steps.

    This only takes effect when ``self.iterations`` is a multiple of
    ``self.stabilize``. The variable is then projected with
    ``manifold.projx``, presumably to undo accumulated retraction error.
    The ``m`` slot is then projected with ``manifold.proju`` at the
    re-projected point.
    """
    # NOTE(review): the condition is a tensor evaluated as a Python bool. This
    # is fine eagerly but may not trace in graph mode without AutoGraph.
    if not (math_ops.floor_mod(self.iterations, self.stabilize) == 0):
        return

    manifold = get_manifold(var)
    first_moment = self.get_slot(var, "m")
    var.assign(manifold.projx(var))
    first_moment.assign(manifold.proju(var, first_moment))