def forward(
    self,
    pos_scores: FloatTensorType,
    neg_scores: FloatTensorType,
) -> FloatTensorType:
    """Return the summed softmax loss of each positive against its negatives.

    For every positive a two-column row of logits is built: column 0 is the
    positive's score and column 1 is the log-sum-exp of that positive's
    negative scores. The loss is the cross entropy of those rows against
    class 0, summed over all positives.

    Args:
        pos_scores: Scores of the positives, shape ``(num_pos,)``.
        neg_scores: Scores of the negatives, shape ``(num_pos, num_neg)``.

    Returns:
        A zero-dimensional tensor with the summed loss. It is zero when
        there are no positives or no negatives.
    """
    num_pos = match_shape(pos_scores, -1)
    num_neg = match_shape(neg_scores, num_pos, -1)

    # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
    # and https://github.com/pytorch/pytorch/issues/15223.
    if num_pos == 0 or num_neg == 0:
        # Allocate on the inputs' device so the result can be combined with
        # losses computed from tensors that live there.
        return torch.zeros((), device=pos_scores.device, requires_grad=True)

    scores = torch.cat(
        [pos_scores.unsqueeze(1), neg_scores.logsumexp(dim=1, keepdim=True)],
        dim=1,
    )
    # Derive the target from pos_scores so it sits on the same device as
    # the logits; a bare torch.zeros(...) would land on the default device.
    target = pos_scores.new_zeros((), dtype=torch.long).expand(num_pos)
    return F.cross_entropy(scores, target, reduction="sum")
def forward(
    self,
    pos_scores: FloatTensorType,
    neg_scores: FloatTensorType,
    weight: Optional[FloatTensorType],
) -> FloatTensorType:
    """Return the (optionally weighted) summed logistic loss.

    Positives are scored against a target of 1 and negatives against a
    target of 0 using binary cross entropy on logits. The negative term is
    divided by the number of negatives per positive before being added to
    the positive term.

    Args:
        pos_scores: Scores of the positives, shape ``(num_pos,)``.
        neg_scores: Scores of the negatives, shape ``(num_pos, num_neg)``.
        weight: Optional per-positive weights, shape ``(num_pos,)``. Each
            weight also applies to every negative of that positive.

    Returns:
        A zero-dimensional tensor with the combined loss.
    """
    n_pos = match_shape(pos_scores, -1)
    n_neg = match_shape(neg_scores, n_pos, -1)

    if weight is None:
        per_negative_weight = None
    else:
        match_shape(weight, n_pos)
        # Add a trailing axis so the weights broadcast over the negatives.
        per_negative_weight = weight.unsqueeze(-1)

    positive_target = pos_scores.new_ones(()).expand(n_pos)
    negative_target = neg_scores.new_zeros(()).expand(n_pos, n_neg)

    positive_term = F.binary_cross_entropy_with_logits(
        pos_scores, positive_target, weight=weight, reduction="sum"
    )
    negative_term = F.binary_cross_entropy_with_logits(
        neg_scores, negative_target, weight=per_negative_weight, reduction="sum"
    )

    # Average over negatives; guard against dividing by zero negatives.
    if n_neg > 0:
        negative_scale = 1 / n_neg
    else:
        negative_scale = 0
    return positive_term + negative_scale * negative_term
def forward(
    self,
    pos_scores: FloatTensorType,
    neg_scores: FloatTensorType,
) -> FloatTensorType:
    """Return the summed softmax loss of each positive against its negatives.

    Each positive yields a row of two logits: its own score, and the
    log-sum-exp of its negatives' scores. The loss is the cross entropy of
    those rows against class 0, summed over all positives.

    Args:
        pos_scores: Scores of the positives, shape ``(num_pos,)``.
        neg_scores: Scores of the negatives, shape ``(num_pos, num_neg)``.

    Returns:
        A zero-dimensional tensor with the summed loss. It is zero when
        there are no positives or no negatives.
    """
    n_pos = match_shape(pos_scores, -1)
    n_neg = match_shape(neg_scores, n_pos, -1)

    # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
    # and https://github.com/pytorch/pytorch/issues/15223.
    if 0 in (n_pos, n_neg):
        return torch.zeros((), device=pos_scores.device, requires_grad=True)

    positive_column = pos_scores[:, None]
    negative_column = torch.logsumexp(neg_scores, dim=1, keepdim=True)
    logits = torch.cat([positive_column, negative_column], dim=1)

    # Class 0 (the positive) is the correct answer for every row.
    target = pos_scores.new_zeros((), dtype=torch.long).expand(n_pos)
    return F.cross_entropy(logits, target, reduction='sum')
def test_zero_dimensions(self):
    """A zero-dimensional tensor matches only an empty or ellipsis pattern."""
    scalar = torch.zeros(())

    for accepted in ((), (...,)):
        self.assertIsNone(match_shape(scalar, *accepted))

    # Any explicit dimension, wildcard or not, is one dimension too many.
    for dim in (0, 1, -1):
        with self.assertRaises(TypeError):
            match_shape(scalar, dim)
def forward(self, pos_scores: FloatTensorType, neg_scores: FloatTensorType) -> FloatTensorType:
    """Return the summed logistic loss of positives and negatives.

    Positives are scored against a target of 1 and negatives against a
    target of 0 using binary cross entropy on logits. The negative term is
    divided by the number of negatives per positive before being added to
    the positive term.

    Args:
        pos_scores: Scores of the positives, shape ``(num_pos,)``.
        neg_scores: Scores of the negatives, shape ``(num_pos, num_neg)``.

    Returns:
        A zero-dimensional tensor with the combined loss.
    """
    n_pos = match_shape(pos_scores, -1)
    n_neg = match_shape(neg_scores, n_pos, -1)

    # Targets are expanded views of a single scalar on the scores' device.
    positive_target = pos_scores.new_ones(()).expand(n_pos)
    negative_target = neg_scores.new_zeros(()).expand(n_pos, n_neg)

    positive_term = F.binary_cross_entropy_with_logits(
        pos_scores, positive_target, reduction="sum"
    )
    negative_term = F.binary_cross_entropy_with_logits(
        neg_scores, negative_target, reduction="sum"
    )

    # Average over negatives; guard against dividing by zero negatives.
    if n_neg > 0:
        negative_scale = 1 / n_neg
    else:
        negative_scale = 0
    return positive_term + negative_scale * negative_term
def forward(self, pos_scores: FloatTensorType, neg_scores: FloatTensorType) -> FloatTensorType:
    """Return the summed margin ranking loss of negatives against positives.

    With a target of -1, ``F.margin_ranking_loss`` charges each
    (positive, negative) pair ``max(0, neg - pos + margin)``; the charges
    are summed over all pairs.

    Args:
        pos_scores: Scores of the positives, shape ``(num_pos,)``.
        neg_scores: Scores of the negatives, shape ``(num_pos, num_neg)``.

    Returns:
        A zero-dimensional tensor with the summed loss. It is zero when
        there are no positives or no negatives.
    """
    n_pos = match_shape(pos_scores, -1)
    n_neg = match_shape(neg_scores, n_pos, -1)

    # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15223.
    if 0 in (n_pos, n_neg):
        return torch.zeros((), device=pos_scores.device, requires_grad=True)

    # A (num_pos, 1) column broadcasts each positive over its negatives.
    positive_column = pos_scores.unsqueeze(1)
    # -1 asks for the second input (the positive) to be ranked higher.
    prefer_positive = pos_scores.new_full((1, 1), -1, dtype=torch.float)
    return F.margin_ranking_loss(
        neg_scores,
        positive_column,
        target=prefer_positive,
        margin=self.margin,
        reduction="sum",
    )
def test_bad_args(self):
    """Malformed patterns and non-tensor inputs are rejected."""
    empty = torch.empty((0,))

    # A repeated ellipsis and a non-integer entry are both invalid patterns.
    for bad_pattern in ((..., ...), ("foo",)):
        with self.assertRaises(RuntimeError):
            match_shape(empty, *bad_pattern)

    with self.assertRaises(AttributeError):
        match_shape(None)
def forward(
    self,
    pos_scores: FloatTensorType,
    neg_scores: FloatTensorType,
) -> FloatTensorType:
    """Return the summed logistic loss of positives and negatives.

    Positives are scored against a target of 1 and negatives against a
    target of 0 using binary cross entropy on logits. The negative term is
    divided by the number of negatives per positive before being added to
    the positive term.

    Args:
        pos_scores: Scores of the positives, shape ``(num_pos,)``.
        neg_scores: Scores of the negatives, shape ``(num_pos, num_neg)``.

    Returns:
        A zero-dimensional tensor with the combined loss.
    """
    num_pos = match_shape(pos_scores, -1)
    num_neg = match_shape(neg_scores, num_pos, -1)
    # Average over negatives; guard against dividing by zero negatives.
    neg_weight = 1 / num_neg if num_neg > 0 else 0

    # Build the targets from the score tensors so they share their device
    # and dtype; bare torch.ones/torch.zeros would use the defaults and
    # fail for scores that live elsewhere (e.g. on a GPU).
    pos_loss = F.binary_cross_entropy_with_logits(
        pos_scores,
        pos_scores.new_ones(()).expand(num_pos),
        reduction='sum',
    )
    neg_loss = F.binary_cross_entropy_with_logits(
        neg_scores,
        neg_scores.new_zeros(()).expand(num_pos, num_neg),
        reduction='sum',
    )
    return pos_loss + neg_weight * neg_loss
def forward(
    self,
    pos_scores: FloatTensorType,
    neg_scores: FloatTensorType,
    weight: Optional[FloatTensorType],
) -> FloatTensorType:
    """Return the (optionally weighted) summed softmax loss.

    Each positive yields a row of two logits: its own score, and the
    log-sum-exp of its negatives' scores. The per-positive loss is the
    cross entropy of that row against class 0. When ``weight`` is given,
    each per-positive loss is multiplied by its weight before summing.

    Args:
        pos_scores: Scores of the positives, shape ``(num_pos,)``.
        neg_scores: Scores of the negatives, shape ``(num_pos, num_neg)``.
        weight: Optional per-positive weights, shape ``(num_pos,)``.

    Returns:
        A zero-dimensional tensor with the summed loss. It is zero when
        there are no positives or no negatives.
    """
    n_pos = match_shape(pos_scores, -1)
    n_neg = match_shape(neg_scores, n_pos, -1)

    # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
    # and https://github.com/pytorch/pytorch/issues/15223.
    if 0 in (n_pos, n_neg):
        return torch.zeros((), device=pos_scores.device, requires_grad=True)

    logits = torch.cat(
        [pos_scores.unsqueeze(1), neg_scores.logsumexp(dim=1, keepdim=True)],
        dim=1,
    )
    # Class 0 (the positive) is the correct answer for every row.
    target = pos_scores.new_zeros((), dtype=torch.long).expand(n_pos)

    if weight is None:
        # No weights: let cross_entropy reduce directly.
        summed = F.cross_entropy(logits, target, reduction="sum")
        return summed.sum()

    per_positive = F.cross_entropy(logits, target, reduction="none")
    match_shape(weight, n_pos)
    return (per_positive * weight).sum()
def forward(
    self,
    pos_scores: FloatTensorType,
    neg_scores: FloatTensorType,
    weight: Optional[FloatTensorType],
) -> FloatTensorType:
    """Return the (optionally weighted) summed margin ranking loss.

    With a target of -1, ``F.margin_ranking_loss`` charges each
    (positive, negative) pair ``max(0, neg - pos + margin)``. When
    ``weight`` is given, every pair's charge is multiplied by the weight of
    its positive before summing.

    Args:
        pos_scores: Scores of the positives, shape ``(num_pos,)``.
        neg_scores: Scores of the negatives, shape ``(num_pos, num_neg)``.
        weight: Optional per-positive weights, shape ``(num_pos,)``.

    Returns:
        A zero-dimensional tensor with the summed loss. It is zero when
        there are no positives or no negatives.
    """
    n_pos = match_shape(pos_scores, -1)
    n_neg = match_shape(neg_scores, n_pos, -1)

    # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15223.
    if 0 in (n_pos, n_neg):
        return torch.zeros((), device=pos_scores.device, requires_grad=True)

    # A (num_pos, 1) column broadcasts each positive over its negatives.
    positive_column = pos_scores.unsqueeze(1)
    # -1 asks for the second input (the positive) to be ranked higher.
    prefer_positive = pos_scores.new_full((1, 1), -1, dtype=torch.float)

    if weight is None:
        # More memory efficient: reduce inside the loss when unweighted.
        return F.margin_ranking_loss(
            neg_scores,
            positive_column,
            target=prefer_positive,
            margin=self.margin,
            reduction="sum",
        )

    match_shape(weight, n_pos)
    per_pair = F.margin_ranking_loss(
        neg_scores,
        positive_column,
        target=prefer_positive,
        margin=self.margin,
        reduction="none",
    )
    return (per_pair * weight.unsqueeze(-1)).sum()
def test_many_dimension(self):
    """Check match_shape against a 3-D tensor of shape (3, 4, 5)."""
    cube = torch.zeros((3, 4, 5))

    # Patterns without -1 wildcards match and capture nothing.
    matches_without_capture = [
        (3, 4, 5),
        (...,),
        (..., 5),
        (3, ..., 5),
        (3, 4, 5, ...),
    ]
    for pattern in matches_without_capture:
        self.assertIsNone(match_shape(cube, *pattern))

    # Each -1 captures the size found at its position; one capture comes
    # back as a bare int, several as a tuple.
    matches_with_capture = [
        ((-1, 4, 5), 3),
        ((-1, ...), 3),
        ((-1, 4, ...), 3),
        ((-1, ..., 5), 3),
        ((-1, 4, -1), (3, 5)),
        ((..., -1, -1), (4, 5)),
        ((-1, -1, -1), (3, 4, 5)),
        ((-1, -1, ..., -1), (3, 4, 5)),
    ]
    for pattern, captured in matches_with_capture:
        self.assertEqual(match_shape(cube, *pattern), captured)

    # Wrong rank, wrong sizes, or more sizes than dimensions.
    mismatches = [
        (),
        (3,),
        (3, 4),
        (5, 4, 3),
        (3, 4, 5, 6),
        (3, 4, ..., 4, 5),
    ]
    for pattern in mismatches:
        with self.assertRaises(TypeError):
            match_shape(cube, *pattern)
def test_one_dimension(self):
    """Check match_shape against a 1-D tensor of shape (3,)."""
    vec = torch.zeros((3,))

    # Patterns without -1 wildcards match and capture nothing.
    for pattern in ((3,), (...,), (3, ...), (..., 3)):
        self.assertIsNone(match_shape(vec, *pattern))

    # A single -1 captures the size as a bare int.
    self.assertEqual(match_shape(vec, -1), 3)

    # Too few or too many sizes for a single dimension.
    for pattern in ((), (3, 1), (3, ..., 3)):
        with self.assertRaises(TypeError):
            match_shape(vec, *pattern)