Example #1
File: losses.py Project: RweBs/PDKE
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
    ) -> FloatTensorType:

        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)
        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
        # and https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            # Create the placeholder on the scores' device so GPU runs don't
            # mix devices.
            return torch.zeros((),
                               device=pos_scores.device,
                               requires_grad=True)

        # print("---", pos_scores.size())
        # print("---", neg_scores.size())
        scores = torch.cat([
            pos_scores.unsqueeze(1),
            neg_scores.logsumexp(dim=1, keepdim=True)
        ],
                           dim=1)

        # print("scores ", type(scores))
        # x = torch.zeros((), dtype=torch.long).expand(num_pos)
        # print("x= ", x)
        loss = F.cross_entropy(
            scores,
            torch.zeros((), dtype=torch.long).expand(num_pos),
            reduction='sum',
        )
        # print("loss", loss)
        return loss
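What this forward computes is a softmax-style ranking loss: each positive score competes against the logsumexp of its sampled negatives, and picking class 0 in the two-column tensor is the same as -pos + logsumexp([pos, negs]). A minimal sketch checking that equivalence (the shapes and scores below are made up for illustration):

import torch
import torch.nn.functional as F

pos = torch.randn(4)      # one score per positive edge
neg = torch.randn(4, 10)  # ten sampled negatives per positive

# The two-column trick from above: column 0 is the positive score, column 1
# collapses all negatives into a single logsumexp.
scores = torch.cat([pos.unsqueeze(1), neg.logsumexp(dim=1, keepdim=True)], dim=1)
loss = F.cross_entropy(scores, torch.zeros(4, dtype=torch.long), reduction="sum")

# Written out directly: loss_i = -pos_i + logsumexp([pos_i, neg_i1, ..., neg_ik])
direct = (-pos + torch.cat([pos.unsqueeze(1), neg], dim=1).logsumexp(dim=1)).sum()
assert torch.allclose(loss, direct)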
Example #2
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)
        neg_weight = 1 / num_neg if num_neg > 0 else 0

        if weight is not None:
            match_shape(weight, num_pos)
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_scores,
            pos_scores.new_ones(()).expand(num_pos),
            reduction="sum",
            weight=weight,
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_scores,
            neg_scores.new_zeros(()).expand(num_pos, num_neg),
            reduction="sum",
            weight=weight.unsqueeze(-1) if weight is not None else None,
        )

        loss = pos_loss + neg_weight * neg_loss

        return loss
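Here each positive is a label-1 logistic sample and each negative a label-0 sample, with the per-positive weight broadcast across that row's negatives and the whole negative term scaled by 1/num_neg. Since BCE-with-logits on hard 0/1 targets reduces to softplus, the result can be cross-checked by hand; a sketch with made-up tensors:

import torch
import torch.nn.functional as F

pos = torch.randn(3)
neg = torch.randn(3, 5)
weight = torch.rand(3)  # one weight per positive edge

pos_loss = F.binary_cross_entropy_with_logits(
    pos, torch.ones_like(pos), reduction="sum", weight=weight)
neg_loss = F.binary_cross_entropy_with_logits(
    neg, torch.zeros_like(neg), reduction="sum", weight=weight.unsqueeze(-1))
loss = pos_loss + neg_loss / neg.shape[1]

# BCE(x, 1) == softplus(-x) and BCE(x, 0) == softplus(x):
manual = ((weight * F.softplus(-pos)).sum()
          + (weight.unsqueeze(-1) * F.softplus(neg)).sum() / neg.shape[1])
assert torch.allclose(loss, manual)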
Example #3
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)

        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
        # and https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            return torch.zeros((),
                               device=pos_scores.device,
                               requires_grad=True)

        scores = torch.cat([
            pos_scores.unsqueeze(1),
            neg_scores.logsumexp(dim=1, keepdim=True),
        ], dim=1)
        loss = F.cross_entropy(
            scores,
            pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
            reduction='sum',
        )

        return loss
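The early return works around the zero-size-tensor issues linked above: rather than running cat/cross_entropy on empty inputs, it hands back a fresh scalar on the right device that still participates in autograd. A minimal check of that behavior:

import torch

pos_scores = torch.empty(0)  # an empty batch of positives
loss = torch.zeros((), device=pos_scores.device, requires_grad=True)
# The scalar is a graph leaf, so the surrounding training loop can still call
# backward() without special-casing empty batches:
loss.backward()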
Example #4
    def test_zero_dimensions(self):
        t = torch.zeros(())
        self.assertIsNone(match_shape(t))
        self.assertIsNone(match_shape(t, ...))
        with self.assertRaises(TypeError):
            match_shape(t, 0)
        with self.assertRaises(TypeError):
            match_shape(t, 1)
        with self.assertRaises(TypeError):
            match_shape(t, -1)
Example #5
    def forward(self, pos_scores: FloatTensorType,
                neg_scores: FloatTensorType) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)
        neg_weight = 1 / num_neg if num_neg > 0 else 0

        pos_loss = F.binary_cross_entropy_with_logits(
            pos_scores,
            pos_scores.new_ones(()).expand(num_pos),
            reduction="sum",
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_scores,
            neg_scores.new_zeros(()).expand(num_pos, num_neg),
            reduction="sum",
        )
        loss = pos_loss + neg_weight * neg_loss

        return loss
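The 1/num_neg factor keeps the negative term's scale independent of how many negatives were sampled per positive. A quick demonstration, using identical negative scores so the normalized term is visibly constant as the sample count grows:

import torch
import torch.nn.functional as F

for num_neg in (1, 10, 100):
    neg = torch.zeros(2, num_neg)  # two positives, identical negatives
    neg_loss = F.binary_cross_entropy_with_logits(
        neg, torch.zeros_like(neg), reduction="sum")
    print(num_neg, (neg_loss / num_neg).item())  # always 2 * ln(2)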
Example #6
    def forward(self, pos_scores: FloatTensorType,
                neg_scores: FloatTensorType) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)

        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            return torch.zeros((),
                               device=pos_scores.device,
                               requires_grad=True)

        loss = F.margin_ranking_loss(
            neg_scores,
            pos_scores.unsqueeze(1),
            target=pos_scores.new_full((1, 1), -1, dtype=torch.float),
            margin=self.margin,
            reduction="sum",
        )

        return loss
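With the target fixed at -1, margin_ranking_loss asks its second argument (the positive, broadcast across each row of negatives) to outrank the first, and the (1, 1) target broadcasts over every positive/negative pair. The call is therefore a plain hinge on (neg - pos); a sketch verifying this with made-up scores:

import torch
import torch.nn.functional as F

pos = torch.randn(3)
neg = torch.randn(3, 7)
margin = 0.1

loss = F.margin_ranking_loss(
    neg, pos.unsqueeze(1),
    target=pos.new_full((1, 1), -1, dtype=torch.float),
    margin=margin, reduction="sum")

# target=-1 means "input2 should be larger", i.e. a hinge on (neg - pos):
manual = (margin + neg - pos.unsqueeze(1)).clamp(min=0).sum()
assert torch.allclose(loss, manual)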
Example #7
    def test_bad_args(self):
        t = torch.empty((0,))
        with self.assertRaises(RuntimeError):
            match_shape(t, ..., ...)
        with self.assertRaises(RuntimeError):
            match_shape(t, "foo")
        with self.assertRaises(AttributeError):
            match_shape(None)
Example #8
File: losses.py Project: RweBs/PDKE
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)
        neg_weight = 1 / num_neg if num_neg > 0 else 0

        # Build the targets on the scores' device and dtype; a bare
        # torch.ones/torch.zeros would allocate on the CPU.
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_scores,
            pos_scores.new_ones(()).expand(num_pos),
            reduction='sum',
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_scores,
            neg_scores.new_zeros(()).expand(num_pos, num_neg),
            reduction='sum',
        )
        loss = pos_loss + neg_weight * neg_loss

        return loss
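One detail worth calling out: the BCE targets are built with new_ones/new_zeros rather than torch.ones/torch.zeros, since the latter allocate on the default (CPU) device and would break once the scores live on a GPU. A small illustration:

import torch

scores = torch.randn(4)  # imagine these live on a GPU during real training

# new_ones inherits the scores' device and dtype, whereas torch.ones(())
# always allocates on the default (CPU) device.
target = scores.new_ones(()).expand(scores.shape[0])
assert target.device == scores.device and target.dtype == scores.dtype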
Example #9
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)

        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
        # and https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            return torch.zeros((),
                               device=pos_scores.device,
                               requires_grad=True)

        scores = torch.cat([
            pos_scores.unsqueeze(1),
            neg_scores.logsumexp(dim=1, keepdim=True),
        ], dim=1)
        if weight is not None:
            loss_per_sample = F.cross_entropy(
                scores,
                pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
                reduction="none",
            )
            match_shape(weight, num_pos)
            loss_per_sample = loss_per_sample * weight
        else:
            loss_per_sample = F.cross_entropy(
                scores,
                pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
                reduction="sum",
            )

        return loss_per_sample.sum()
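The weighted branch switches to reduction="none" and multiplies by weight manually because F.cross_entropy's own weight argument weights classes, not samples. The pattern in isolation, with made-up tensors:

import torch
import torch.nn.functional as F

scores = torch.randn(5, 2)                 # columns: [pos, logsumexp(negs)]
target = torch.zeros(5, dtype=torch.long)  # the positive is always class 0
weight = torch.rand(5)                     # one weight per positive edge

per_sample = F.cross_entropy(scores, target, reduction="none")  # shape (5,)
loss = (per_sample * weight).sum()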
Example #10
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)

        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            return torch.zeros((),
                               device=pos_scores.device,
                               requires_grad=True)

        if weight is not None:
            match_shape(weight, num_pos)
            loss_per_sample = F.margin_ranking_loss(
                neg_scores,
                pos_scores.unsqueeze(1),
                target=pos_scores.new_full((1, 1), -1, dtype=torch.float),
                margin=self.margin,
                reduction="none",
            )
            loss = (loss_per_sample * weight.unsqueeze(-1)).sum()
        else:
            # More memory-efficient path when no weights are given.
            loss = F.margin_ranking_loss(
                neg_scores,
                pos_scores.unsqueeze(1),
                target=pos_scores.new_full((1, 1), -1, dtype=torch.float),
                margin=self.margin,
                reduction="sum",
            )

        return loss
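The two branches differ only in where the reduction happens; with an all-ones weight they agree exactly, which makes a quick sanity check possible:

import torch
import torch.nn.functional as F

pos = torch.randn(3)
neg = torch.randn(3, 4)
target = pos.new_full((1, 1), -1, dtype=torch.float)

per_sample = F.margin_ranking_loss(
    neg, pos.unsqueeze(1), target=target, margin=0.1, reduction="none")
weighted = (per_sample * torch.ones(3).unsqueeze(-1)).sum()

summed = F.margin_ranking_loss(
    neg, pos.unsqueeze(1), target=target, margin=0.1, reduction="sum")
assert torch.allclose(weighted, summed)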
Example #11
    def test_many_dimension(self):
        t = torch.zeros((3, 4, 5))
        self.assertIsNone(match_shape(t, 3, 4, 5))
        self.assertIsNone(match_shape(t, ...))
        self.assertIsNone(match_shape(t, ..., 5))
        self.assertIsNone(match_shape(t, 3, ..., 5))
        self.assertIsNone(match_shape(t, 3, 4, 5, ...))
        self.assertEqual(match_shape(t, -1, 4, 5), 3)
        self.assertEqual(match_shape(t, -1, ...), 3)
        self.assertEqual(match_shape(t, -1, 4, ...), 3)
        self.assertEqual(match_shape(t, -1, ..., 5), 3)
        self.assertEqual(match_shape(t, -1, 4, -1), (3, 5))
        self.assertEqual(match_shape(t, ..., -1, -1), (4, 5))
        self.assertEqual(match_shape(t, -1, -1, -1), (3, 4, 5))
        self.assertEqual(match_shape(t, -1, -1, ..., -1), (3, 4, 5))
        with self.assertRaises(TypeError):
            match_shape(t)
        with self.assertRaises(TypeError):
            match_shape(t, 3)
        with self.assertRaises(TypeError):
            match_shape(t, 3, 4)
        with self.assertRaises(TypeError):
            match_shape(t, 5, 4, 3)
        with self.assertRaises(TypeError):
            match_shape(t, 3, 4, 5, 6)
        with self.assertRaises(TypeError):
            match_shape(t, 3, 4, ..., 4, 5)
Example #12
    def test_one_dimension(self):
        t = torch.zeros((3,))
        self.assertIsNone(match_shape(t, 3))
        self.assertIsNone(match_shape(t, ...))
        self.assertIsNone(match_shape(t, 3, ...))
        self.assertIsNone(match_shape(t, ..., 3))
        self.assertEqual(match_shape(t, -1), 3)
        with self.assertRaises(TypeError):
            match_shape(t)
        with self.assertRaises(TypeError):
            match_shape(t, 3, 1)
        with self.assertRaises(TypeError):
            match_shape(t, 3, ..., 3)
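Taken together, the tests in examples #4, #7, #11 and #12 pin down match_shape's contract: literal ints must match exactly, each -1 captures one dimension's size (returned as None, a single int, or a tuple), at most one ... skips a run of dimensions, shape mismatches raise TypeError, malformed patterns raise RuntimeError, and a non-tensor fails with AttributeError from the .size() call. The real helper ships with PyTorch-BigGraph; a hypothetical implementation consistent with these tests might look like:

def match_shape(tensor, *expected):
    """Check a tensor's shape against a pattern of ints, -1 wildcards and at
    most one Ellipsis; return the sizes captured by the -1s."""
    actual = tuple(tensor.size())  # AttributeError for non-tensors, as tested
    if expected.count(...) > 1:
        raise RuntimeError("At most one ... is allowed, got %r" % (expected,))
    if any(d is not ... and not isinstance(d, int) for d in expected):
        raise RuntimeError("Patterns must be ints or ..., got %r" % (expected,))
    if ... in expected:
        split = expected.index(...)
        head, tail = expected[:split], expected[split + 1:]
        if len(head) + len(tail) > len(actual):
            raise TypeError("Shape %r doesn't match %r" % (actual, expected))
        # Dims swallowed by ... are neither checked nor captured.
        pairs = (list(zip(head, actual))
                 + list(zip(tail, actual[len(actual) - len(tail):])))
    else:
        if len(expected) != len(actual):
            raise TypeError("Shape %r doesn't match %r" % (actual, expected))
        pairs = list(zip(expected, actual))
    captured = []
    for want, got in pairs:
        if want == -1:
            captured.append(got)
        elif want != got:
            raise TypeError("Shape %r doesn't match %r" % (actual, expected))
    if len(captured) == 0:
        return None
    if len(captured) == 1:
        return captured[0]
    return tuple(captured)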