예제 #1
0
    def _process_one_batch(
        self, model: MultiRelationEmbedder, batch_edges: EdgeList
    ) -> Stats:
        """Run a single training step on one batch of edges.

        Clears the model's gradients and scores ``batch_edges`` with ``model``.
        It then computes the loss via ``self.calc_loss`` and backpropagates.
        When the model returns a regularization term, the backpropagated
        objective is ``loss + reg``; otherwise it is just ``loss``. Finally it
        steps ``self.model_optimizer``, then every optimizer in
        ``self.unpartitioned_optimizers``, then every optimizer in
        ``self.partitioned_optimizers``, in that order.

        Args:
            model: The embedder. Called as ``model(batch_edges)``, it must
                return a ``(scores, reg)`` pair where ``reg`` may be ``None``.
            batch_edges: The edges making up this batch.

        Returns:
            A ``Stats`` record containing:

            - the loss value;
            - the regularization value (``0.0`` when the model returned none);
            - per side, the number of negative scores that exceed the
              corresponding positive score;
            - the number of edges in the batch.
        """
        model.zero_grad()

        scores, reg = model(batch_edges)
        loss = self.calc_loss(scores, batch_edges)
        has_reg = reg is not None

        # Positive scores are unsqueezed on dim 1 so they broadcast against
        # the negatives. Presumably pos is (batch,) and neg is
        # (batch, num_negs) -- TODO confirm against the model's score type.
        lhs_violations = (scores.lhs_neg > scores.lhs_pos.unsqueeze(1)).sum()
        rhs_violations = (scores.rhs_neg > scores.rhs_pos.unsqueeze(1)).sum()

        # Stats are taken as plain Python scalars before the backward pass.
        stats = Stats(
            loss=float(loss),
            reg=float(reg) if has_reg else 0.0,
            violators_lhs=int(lhs_violations),
            violators_rhs=int(rhs_violations),
            count=len(batch_edges),
        )

        objective = loss + reg if has_reg else loss
        objective.backward()

        # Order matters only for reproducibility: model-level first, then
        # unpartitioned, then partitioned optimizers.
        for opt in (
            self.model_optimizer,
            *self.unpartitioned_optimizers.values(),
            *self.partitioned_optimizers.values(),
        ):
            opt.step(closure=None)

        return stats
예제 #2
0
    def process_one_batch(
        self,
        model: MultiRelationEmbedder,
        batch_edges: EdgeList,
    ) -> Stats:
        """Run a single training step on one batch of edges.

        Clears the model's gradients and scores ``batch_edges`` with ``model``.
        It applies ``self.loss_fn`` to the left-hand-side and right-hand-side
        positive/negative score pairs and sums the two losses. The sum is
        scaled by the ``weight`` of the batch's relation, and the result is
        backpropagated. Finally it steps ``self.global_optimizer`` followed by
        every optimizer in ``self.entity_optimizers``.

        Args:
            model: The embedder. Called as ``model(batch_edges)``, it returns a
                scores object exposing ``lhs_pos``, ``lhs_neg``, ``rhs_pos``
                and ``rhs_neg``.
            batch_edges: The edges making up this batch.

        Returns:
            A ``Stats`` record containing:

            - the weighted loss value;
            - per side, the number of negative scores that exceed the
              corresponding positive score;
            - the number of edges in the batch.
        """
        model.zero_grad()

        scores = model(batch_edges)

        lhs_loss = self.loss_fn(scores.lhs_pos, scores.lhs_neg)
        rhs_loss = self.loss_fn(scores.rhs_pos, scores.rhs_neg)

        # NOTE(review): batches without a single scalar relation type are
        # weighted by relation 0. Presumably intended when all relations
        # share one config -- confirm.
        if batch_edges.has_scalar_relation_type():
            relation_idx = batch_edges.get_relation_type_as_scalar()
        else:
            relation_idx = 0
        relation = self.relations[relation_idx]
        loss = relation.weight * (lhs_loss + rhs_loss)

        def count_violators(pos, neg):
            # Positive scores are unsqueezed on dim 1 to broadcast against the
            # negatives. Presumably pos is (batch,) and neg is
            # (batch, num_negs) -- TODO confirm.
            return int((neg > pos.unsqueeze(1)).sum())

        stats = Stats(
            loss=float(loss),
            violators_lhs=count_violators(scores.lhs_pos, scores.lhs_neg),
            violators_rhs=count_violators(scores.rhs_pos, scores.rhs_neg),
            count=len(batch_edges),
        )

        loss.backward()
        for opt in (self.global_optimizer, *self.entity_optimizers.values()):
            opt.step(closure=None)

        return stats