def test_layer(self):
    """ManifoldEmbedding keeps its manifold after a get_config round trip.

    Builds a layer on a Grassmannian, checks that its ``embeddings`` variable
    reports that manifold type, then rebuilds the layer from its config and
    checks the same for the rebuilt layer.
    """
    manifold = Grassmannian()
    expected_type = type(manifold)

    def check_embedding_manifold(embedding_layer):
        # The manifold attached to the embedding weights must match the one
        # the layer was configured with.
        self.assertEqual(
            type(get_manifold(embedding_layer.embeddings)), expected_type
        )

    with self.cached_session(use_gpu=True):
        inputs = tf.keras.Input((5, 2))

        original = ManifoldEmbedding(1000, 64, manifold=manifold)
        original(inputs)  # calling the layer creates its weights
        check_embedding_manifold(original)

        restored = ManifoldEmbedding(**original.get_config())
        restored(inputs)
        check_embedding_manifold(restored)
    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply one Riemannian SGD step to ``var`` from a dense gradient.

        The gradient is converted with ``manifold.egrad2rgrad`` and the step
        is taken with ``manifold.retr``. With momentum enabled, the updated
        momentum is moved to the new point with ``manifold.transp`` before it
        is stored back in the ``"momentum"`` slot.

        Args:
            grad: Gradient tensor for ``var``.
            var: The resource variable being optimized.
            apply_state: Optional dict keyed by ``(device, dtype)`` holding
                precomputed coefficients such as ``"lr_t"``; when absent,
                ``self._fallback_apply_state`` supplies them.
        """
        device, dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (device, dtype)) or self._fallback_apply_state(device, dtype)

        manifold = get_manifold(var)
        rgrad = manifold.egrad2rgrad(var, grad)
        lr = coefficients["lr_t"]

        if not self._momentum:
            # Plain step: retract along the negative scaled gradient.
            var.assign(manifold.retr(var, -rgrad * lr))
        else:
            # NOTE(review): self._momentum is used both as the on/off test and
            # as the decay factor. If the constructor stores a boolean flag
            # here (as Keras' SGD does), the effective decay is 1.0 — confirm.
            velocity = self.get_slot(var, "momentum")
            velocity_t = velocity * self._momentum - rgrad * lr
            if self.nesterov:
                # Nesterov look-ahead step.
                step = velocity_t * self._momentum - rgrad * lr
            else:
                step = velocity_t
            var_t = manifold.retr(var, step)
            # Move the momentum to the tangent space at the new point.
            velocity.assign(manifold.transp(var, var_t, velocity_t))
            var.assign(var_t)

        if self.stabilize is not None:
            self._stabilize(var)
 def _stabilize(self, var):
     """Every ``self.stabilize`` iterations, re-project ``var`` onto its manifold.

     ``var`` is projected with ``manifold.projx``. When momentum is enabled,
     the ``"momentum"`` slot is then projected with ``manifold.proju`` at the
     re-projected point.
     """
     due = math_ops.floor_mod(self.iterations, self.stabilize) == 0
     if due:
         manifold = get_manifold(var)
         var.assign(manifold.projx(var))
         if self._momentum:
             velocity = self.get_slot(var, "momentum")
             velocity.assign(manifold.proju(var, velocity))
    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply one Riemannian Adam step to ``var`` from a dense gradient.

        The first moment ``m`` is kept as a tangent vector. The second moment
        ``v`` accumulates ``manifold.inner(var, g, g, keepdims=True)`` rather
        than an elementwise square. The step is taken with ``manifold.retr``,
        and ``m`` is then moved to the new point with ``manifold.transp``.

        Args:
            grad: Gradient tensor for ``var``.
            var: The resource variable being optimized.
            apply_state: Optional dict keyed by ``(device, dtype)`` with the
                coefficients read below (``lr_t``, ``beta_*``, ``epsilon``).
        """
        device, dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (device, dtype)) or self._fallback_apply_state(device, dtype)

        first_moment = self.get_slot(var, "m")
        second_moment = self.get_slot(var, "v")

        manifold = get_manifold(var)
        rgrad = manifold.egrad2rgrad(var, grad)

        # Bias-corrected step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t).
        corrected_lr = coefficients["lr_t"] * math_ops.sqrt(
            1 - coefficients["beta_2_power"])
        step_size = corrected_lr / (1 - coefficients["beta_1_power"])

        # In-place exponential moving averages: x += (new - x) * (1 - beta).
        first_moment.assign_add(
            (rgrad - first_moment) * (1 - coefficients["beta_1_t"]))
        squared_norm = manifold.inner(var, rgrad, rgrad, keepdims=True)
        second_moment.assign_add(
            (squared_norm - second_moment) * (1 - coefficients["beta_2_t"]))

        # AMSGrad divides by the running maximum; the "v" slot itself stays
        # the plain moving average.
        denominator_source = second_moment
        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat.assign(math_ops.maximum(vhat, second_moment))
            denominator_source = vhat

        direction = -(first_moment * step_size) / (
            math_ops.sqrt(denominator_source) + coefficients["epsilon"])
        var_t = manifold.retr(var, direction)
        first_moment.assign(manifold.transp(var, var_t, first_moment))
        var.assign(var_t)

        if self.stabilize is not None:
            self._stabilize(var)
示例#5
0
 def test_variable(self):
     """get_manifold / assign_to_manifold behave as expected on small variables.

     A freshly created variable reports the Euclidean manifold type.
     ``assign_to_manifold`` must raise ``ValueError`` for the 1x2 variable
     and succeed for the 2x1 one, after which the 2x1 variable reports the
     Grassmannian type.
     """
     euclidean = Euclidean()
     grassmannian = Grassmannian()

     def manifold_type(var):
         return type(variable.get_manifold(var))

     with self.cached_session(use_gpu=True):
         row = tf.Variable([[5, 3]])
         column = tf.Variable([[5], [3]])

         # Variables without an explicit manifold default to Euclidean.
         self.assertEqual(manifold_type(row), type(euclidean))
         self.assertEqual(manifold_type(row), manifold_type(column))

         with self.assertRaises(ValueError):
             variable.assign_to_manifold(row, grassmannian)
         variable.assign_to_manifold(column, grassmannian)
         self.assertEqual(manifold_type(column), type(grassmannian))
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply one Riemannian Adam step to the rows of ``var`` in ``indices``.

        Only the gathered rows of ``var``, ``m``, ``v`` (and ``vhat`` under
        AMSGrad) are read and written back with ``scatter_*`` ops. Because
        ``grad`` only holds values for those rows, every point-dependent
        manifold operation (``egrad2rgrad``, ``inner``, ``retr``,
        ``transp``) uses the gathered rows as base points.

        Args:
            grad: Gradient values for the rows selected by ``indices``.
            var: The resource variable being optimized.
            indices: Row indices of ``var`` that ``grad`` refers to.
            apply_state: Optional dict keyed by ``(device, dtype)`` with the
                coefficients read below; when absent,
                ``self._fallback_apply_state`` supplies them.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)) or self._fallback_apply_state(
                var_device, var_dtype)

        manifold = get_manifold(var)
        # Gather first. Projecting or measuring against the full variable
        # would pair the gathered gradient rows with the wrong (or
        # mis-shaped) base points, while retr/transp below already work on
        # these rows.
        var_values = array_ops.gather(var, indices)
        grad = manifold.egrad2rgrad(var_values, grad)

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t_values = (array_ops.gather(m, indices) * coefficients["beta_1_t"] +
                      m_scaled_g_values)

        # v_t = beta2 * v + (1 - beta2) * <g_t, g_t>_x
        v = self.get_slot(var, "v")
        v_scaled_g_values = (
            manifold.inner(var_values, grad, grad, keepdims=True) *
            coefficients["one_minus_beta_2_t"])
        v_t_values = (array_ops.gather(v, indices) * coefficients["beta_2_t"] +
                      v_scaled_g_values)

        # AMSGrad divides by the running maximum, but the "v" slot must keep
        # the plain moving average (as the dense update does). Use a separate
        # name so v_t_values is not overwritten with vhat.
        v_denom_values = v_t_values
        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat.scatter_max(ops.IndexedSlices(v_t_values, indices))
            v_denom_values = array_ops.gather(vhat, indices)

        # NOTE(review): unlike the dense path, no bias correction is computed
        # here; presumably coefficients["lr"] already includes it (as in
        # Keras' Adam _prepare_local) — confirm.
        var_t_values = manifold.retr(
            var_values,
            -(m_t_values * coefficients["lr"]) /
            (math_ops.sqrt(v_denom_values) + coefficients["epsilon"]),
        )
        # Move the first moment to the tangent space at the new points.
        m_t_transp = manifold.transp(var_values, var_t_values, m_t_values)

        m.scatter_update(ops.IndexedSlices(m_t_transp, indices))
        v.scatter_update(ops.IndexedSlices(v_t_values, indices))
        var.scatter_update(ops.IndexedSlices(var_t_values, indices))

        if self.stabilize is not None:
            self._stabilize(var)
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply one Riemannian SGD step to the rows of ``var`` in ``indices``.

        Only the gathered rows of ``var`` (and of the ``"momentum"`` slot,
        when enabled) are updated via ``scatter_update``. Because ``grad``
        only holds values for those rows, ``egrad2rgrad``, ``retr`` and
        ``transp`` all use the gathered rows as base points.

        Args:
            grad: Gradient values for the rows selected by ``indices``.
            var: The resource variable being optimized.
            indices: Row indices of ``var`` that ``grad`` refers to.
            apply_state: Optional dict keyed by ``(device, dtype)`` holding
                precomputed coefficients such as ``"lr_t"``; when absent,
                ``self._fallback_apply_state`` supplies them.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)) or self._fallback_apply_state(
                var_device, var_dtype)

        manifold = get_manifold(var)
        # Gather first. Converting the gradient against the full variable
        # would pair the gathered gradient rows with the wrong (or
        # mis-shaped) base points, while retr/transp below already work on
        # these rows.
        var_values = array_ops.gather(var, indices)
        grad = manifold.egrad2rgrad(var_values, grad)

        if self._momentum:
            # NOTE(review): self._momentum is used both as the on/off test and
            # as the decay factor. If the constructor stores a boolean flag
            # here (as Keras' SGD does), the effective decay is 1.0 — confirm.
            momentum = self.get_slot(var, "momentum")
            momentum_t_values = (
                array_ops.gather(momentum, indices) * self._momentum -
                grad * coefficients["lr_t"])
            if self.nesterov:
                var_t_values = manifold.retr(
                    var_values,
                    momentum_t_values * self._momentum -
                    grad * coefficients["lr_t"],
                )
            else:
                var_t_values = manifold.retr(var_values, momentum_t_values)
            # Move the momentum to the tangent space at the new points.
            momentum_transp_values = manifold.transp(var_values, var_t_values,
                                                     momentum_t_values)
            momentum.scatter_update(
                ops.IndexedSlices(momentum_transp_values, indices))
        else:
            var_t_values = manifold.retr(var_values,
                                         -grad * coefficients["lr_t"])

        var.scatter_update(ops.IndexedSlices(var_t_values, indices))

        if self.stabilize is not None:
            self._stabilize(var)
 def _stabilize(self, var):
     """Every ``self.stabilize`` iterations, re-project ``var`` and its first moment.

     ``var`` is projected with ``manifold.projx``, and the ``"m"`` slot is
     then projected with ``manifold.proju`` at the re-projected point.
     """
     due = math_ops.floor_mod(self.iterations, self.stabilize) == 0
     if due:
         manifold = get_manifold(var)
         first_moment = self.get_slot(var, "m")
         var.assign(manifold.projx(var))
         first_moment.assign(manifold.proju(var, first_moment))