Example #1: weighted logistic loss (binary cross-entropy with logits) over positive and negative scores
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)
        neg_weight = 1 / num_neg if num_neg > 0 else 0

        if weight is not None:
            match_shape(weight, num_pos)
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_scores,
            pos_scores.new_ones(()).expand(num_pos),
            reduction="sum",
            weight=weight,
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_scores,
            neg_scores.new_zeros(()).expand(num_pos, num_neg),
            reduction="sum",
            weight=weight.unsqueeze(-1) if weight is not None else None,
        )

        loss = pos_loss + neg_weight * neg_loss

        return loss
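For context, here is a minimal standalone sketch of the same computation, assuming pos_scores has shape (num_pos,), neg_scores has shape (num_pos, num_neg), and weight is an optional per-positive tensor of shape (num_pos,). The helper name weighted_logistic_loss is made up for illustration and is not part of the original code.

    import torch
    import torch.nn.functional as F

    def weighted_logistic_loss(pos_scores, neg_scores, weight=None):
        # pos_scores: (num_pos,), neg_scores: (num_pos, num_neg), weight: (num_pos,) or None.
        num_pos, num_neg = neg_scores.shape
        neg_weight = 1.0 / num_neg if num_neg > 0 else 0.0
        # Positives are pushed towards label 1, negatives towards label 0.
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_scores, torch.ones_like(pos_scores), reduction="sum", weight=weight
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_scores, torch.zeros_like(neg_scores), reduction="sum",
            weight=weight.unsqueeze(-1) if weight is not None else None,
        )
        # Down-weight the negatives so they count as much as the positives overall.
        return pos_loss + neg_weight * neg_loss

    pos = torch.randn(4)       # scores of 4 positive edges
    neg = torch.randn(4, 10)   # 10 sampled negatives per positive
    w = torch.rand(4)          # optional per-positive weights
    print(weighted_logistic_loss(pos, neg, w))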
Example #2: softmax ranking loss, where each positive score competes against the logsumexp of its negatives
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)

        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
        # and https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            return torch.zeros((),
                               device=pos_scores.device,
                               requires_grad=True)

        scores = torch.cat([
            pos_scores.unsqueeze(1),
            neg_scores.logsumexp(dim=1, keepdim=True),
        ], dim=1)
        loss = F.cross_entropy(
            scores,
            pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
            reduction='sum',
        )

        return loss
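As a quick sanity check (an illustration, not part of the original code): concatenating each positive score with the logsumexp of its negatives and taking cross-entropy against target class 0 gives the same result as a plain softmax cross-entropy over the positive together with all of its negatives.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    pos = torch.randn(4)      # (num_pos,)
    neg = torch.randn(4, 10)  # (num_pos, num_neg)
    targets = torch.zeros(4, dtype=torch.long)  # the positive is always class 0

    # Form used above: positive vs. the logsumexp of its own negatives.
    scores = torch.cat([pos.unsqueeze(1), neg.logsumexp(dim=1, keepdim=True)], dim=1)
    loss = F.cross_entropy(scores, targets, reduction="sum")

    # Equivalent form: softmax over the positive and all of its negatives.
    full = torch.cat([pos.unsqueeze(1), neg], dim=1)  # (num_pos, 1 + num_neg)
    loss_direct = F.cross_entropy(full, targets, reduction="sum")

    print(torch.allclose(loss, loss_direct))  # True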
Example #3: weighted variant of the softmax ranking loss
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)

        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
        # and https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            return torch.zeros((),
                               device=pos_scores.device,
                               requires_grad=True)

        scores = torch.cat([
            pos_scores.unsqueeze(1),
            neg_scores.logsumexp(dim=1, keepdim=True),
        ], dim=1)
        if weight is not None:
            loss_per_sample = F.cross_entropy(
                scores,
                pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
                reduction="none",
            )
            match_shape(weight, num_pos)
            loss_per_sample = loss_per_sample * weight
        else:
            loss_per_sample = F.cross_entropy(
                scores,
                pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
                reduction="sum",
            )

        return loss_per_sample.sum()
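A short usage sketch of the weighted path, assuming weight holds one multiplier per positive edge; with all-ones weights it collapses to the unweighted "sum" reduction of Example #2.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    pos = torch.randn(4)
    neg = torch.randn(4, 10)
    weight = torch.rand(4)    # per-positive weights, shape (num_pos,)
    targets = torch.zeros(4, dtype=torch.long)

    scores = torch.cat([pos.unsqueeze(1), neg.logsumexp(dim=1, keepdim=True)], dim=1)
    per_sample = F.cross_entropy(scores, targets, reduction="none")  # (num_pos,)
    loss = (per_sample * weight).sum()
    print(loss)

    # With weights of all ones this matches the unweighted "sum" reduction.
    print(torch.allclose((per_sample * torch.ones(4)).sum(),
                         F.cross_entropy(scores, targets, reduction="sum")))  # True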
Example #4: unweighted logistic loss over positive and negative scores
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)
        neg_weight = 1 / num_neg if num_neg > 0 else 0

        pos_loss = F.binary_cross_entropy_with_logits(
            pos_scores,
            pos_scores.new_ones(()).expand(num_pos),
            reduction="sum",
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_scores,
            neg_scores.new_zeros(()).expand(num_pos, num_neg),
            reduction="sum",
        )
        loss = pos_loss + neg_weight * neg_loss

        return loss
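A small numeric check (an illustration, not part of the original code): with a single positive score p and a single negative score n, this loss reduces to softplus(-p) + softplus(n), i.e. -log sigmoid(p) - log(1 - sigmoid(n)).

    import torch
    import torch.nn.functional as F

    p = torch.tensor([0.7])      # one positive score
    n = torch.tensor([[-0.3]])   # one negative score for that positive
    pos_loss = F.binary_cross_entropy_with_logits(p, torch.ones_like(p), reduction="sum")
    neg_loss = F.binary_cross_entropy_with_logits(n, torch.zeros_like(n), reduction="sum")
    loss = pos_loss + (1 / n.size(1)) * neg_loss
    print(torch.allclose(loss, F.softplus(-p) + F.softplus(n.squeeze())))  # True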
Example #5: direction-agnostic forward pass that scores positive edges against sampled negatives in fixed-size chunks
    def forward_direction_agnostic(
        self,
        src: EntityList,
        dst: EntityList,
        rel: Union[int, LongTensorType],
        src_entity_type: str,
        dst_entity_type: str,
        src_operator: Union[None, AbstractOperator, AbstractDynamicOperator],
        dst_operator: Union[None, AbstractOperator, AbstractDynamicOperator],
        src_module: AbstractEmbedding,
        dst_module: AbstractEmbedding,
        src_pos: FloatTensorType,
        dst_pos: FloatTensorType,
        chunk_size: int,
        src_negative_sampling_method: Negatives,
        dst_negative_sampling_method: Negatives,
    ):
        num_pos = len(src)
        assert len(dst) == num_pos

        src_pos = self.adjust_embs(src_pos, rel, src_entity_type, src_operator)
        dst_pos = self.adjust_embs(dst_pos, rel, dst_entity_type, dst_operator)

        num_chunks = ceil_of_ratio(num_pos, chunk_size)
        src_dim = src_pos.size(-1)
        dst_dim = dst_pos.size(-1)
        if num_pos < num_chunks * chunk_size:
            src_padding = src_pos.new_zeros(()).expand(
                (num_chunks * chunk_size - num_pos, src_dim))
            src_pos = torch.cat((src_pos, src_padding), dim=0)
            dst_padding = dst_pos.new_zeros(()).expand(
                (num_chunks * chunk_size - num_pos, dst_dim))
            dst_pos = torch.cat((dst_pos, dst_padding), dim=0)
        src_pos = src_pos.view((num_chunks, chunk_size, src_dim))
        dst_pos = dst_pos.view((num_chunks, chunk_size, dst_dim))

        src_neg, src_ignore_mask = self.prepare_negatives(
            src,
            src_pos,
            src_module,
            src_negative_sampling_method,
            self.num_uniform_negs,
            rel,
            src_entity_type,
            src_operator,
        )
        dst_neg, dst_ignore_mask = self.prepare_negatives(
            dst,
            dst_pos,
            dst_module,
            dst_negative_sampling_method,
            self.num_uniform_negs,
            rel,
            dst_entity_type,
            dst_operator,
        )

        pos_scores, src_neg_scores, dst_neg_scores = self.comparator(
            src_pos, dst_pos, src_neg, dst_neg)

        pos_scores = pos_scores.float()
        src_neg_scores = src_neg_scores.float()
        dst_neg_scores = dst_neg_scores.float()

        # The masks tell us which negative scores (i.e., scores for non-existing
        # edges) must be ignored because they come from pairs we don't actually
        # intend to compare (say, positive pairs or interactions with padding).
        # We do it by replacing them with a "very negative" value so that they
        # are considered spot-on predictions with minimal impact on the loss.
        for ignore_mask in src_ignore_mask:
            src_neg_scores[ignore_mask] = -1e9
        for ignore_mask in dst_ignore_mask:
            dst_neg_scores[ignore_mask] = -1e9

        # De-chunk the scores and ignore the ones whose positives were padding.
        pos_scores = pos_scores.flatten(0, 1)[:num_pos]
        src_neg_scores = src_neg_scores.flatten(0, 1)[:num_pos]
        dst_neg_scores = dst_neg_scores.flatten(0, 1)[:num_pos]
        reg = None
        if self.regularizer is not None:
            assert (src_operator is None) != (dst_operator is None), \
                "Exactly one of src or dst operator should be None"
            operator = src_operator if src_operator is not None else dst_operator
            if self.num_dynamic_rels > 0:
                reg = self.regularizer.forward_dynamic(src_pos, dst_pos,
                                                       operator, rel)
            else:
                reg = self.regularizer.forward(src_pos, dst_pos, operator)

        return pos_scores, src_neg_scores, dst_neg_scores, reg
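The padding-and-chunking step above reshapes the positive embeddings into fixed-size chunks before scoring. A standalone illustration of just that step follows; the inline ceil_of_ratio mirrors how the library helper is assumed to behave.

    import torch

    def ceil_of_ratio(num, den):
        # Ceiling division for positive integers.
        return (num - 1) // den + 1

    num_pos, chunk_size, dim = 10, 4, 8
    pos = torch.randn(num_pos, dim)

    num_chunks = ceil_of_ratio(num_pos, chunk_size)  # 3 chunks of size 4
    if num_pos < num_chunks * chunk_size:
        # Pad with zero rows so the positives divide evenly into chunks.
        padding = pos.new_zeros(()).expand(num_chunks * chunk_size - num_pos, dim)
        pos = torch.cat((pos, padding), dim=0)       # 10 rows -> 12 rows
    pos = pos.view(num_chunks, chunk_size, dim)
    print(pos.shape)  # torch.Size([3, 4, 8])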