Exemplo n.º 1
0
    def _check_constraints(self):
        """Verify that the model respects its parameter constraints.

        - Entity and relation embedding vectors must have an L2 norm (taken
          over the last dimension) of at most one, up to ``EPSILON`` absolute
          tolerance.
        - Entity and relation covariance values must lie within the bounds
          given by ``self.model.c_min`` and ``self.model.c_max``.
        """
        model = self.model
        embedding_modules = [model.entity_embeddings, model.relation_embeddings]
        for module in embedding_modules:
            # indices=None requests the embeddings without selecting specific
            # indices (presumably the full table -- confirm against the model).
            norms = module(indices=None).norm(p=2, dim=-1)
            assert all_in_bounds(norms, high=1., a_tol=EPSILON)
        covariance_modules = [model.entity_covariances, model.relation_covariances]
        for module in covariance_modules:
            values = module(indices=None)
            assert all_in_bounds(values, low=model.c_min, high=model.c_max)
Exemplo n.º 2
0
    def _check_constraints(self):
        """Verify the unit-norm constraint on the embedding weights.

        For both the entity and the relation embeddings, the L2 norm of the
        weight tensor along its last dimension must not exceed one, with an
        absolute tolerance of ``1.0e-06``.
        """
        tolerance = 1.0e-06
        embedding_modules = [
            self.model.entity_embeddings,
            self.model.relation_embeddings,
        ]
        for module in embedding_modules:
            vector_norms = module.weight.norm(p=2, dim=-1)
            assert all_in_bounds(vector_norms, high=1., a_tol=tolerance)
Exemplo n.º 3
0
    def _check_constraints(self):
        """Check model constraints.

        Entity embeddings have to have at most unit L2 norm. The norm is taken
        over the last dimension of the entity embedding weight, and a small
        absolute tolerance (``1.0e-06``) is allowed on the upper bound.
        """
        entity_norms = self.model.entity_embeddings.weight.norm(p=2, dim=-1)
        # Without a tolerance this check is brittle: if the model rescales
        # vectors to unit length, floating-point round-off can leave a norm
        # marginally above 1.0 and spuriously fail the assertion.
        assert all_in_bounds(entity_norms, high=1., a_tol=1.0e-06)
Exemplo n.º 4
0
    def _check_constraints(self):
        """Verify that no entity embedding exceeds unit L2 norm.

        The entity embeddings are fetched with ``indices=None`` and their L2
        norm over the last dimension is compared against 1.0 with an absolute
        tolerance of ``EPSILON``.
        """
        entity_vectors = self.instance.entity_embeddings(indices=None)
        lengths = entity_vectors.norm(p=2, dim=-1)
        assert all_in_bounds(lengths, high=1.0, a_tol=EPSILON)
Exemplo n.º 5
0
    def _check_constraints(self):
        """Verify the unit-norm constraint on both embedding tables.

        For entity and relation embeddings alike, the L2 norm over the last
        dimension must not exceed 1, up to an absolute tolerance of
        ``EPSILON``.
        """
        instance = self.instance
        for representation in [instance.entity_embeddings, instance.relation_embeddings]:
            vector_norms = representation(indices=None).norm(p=2, dim=-1)
            assert all_in_bounds(vector_norms, high=1., a_tol=EPSILON)