Esempio n. 1
0
    def test__generate_all_triplets_multiple_negative(self):
        """
        One same-label pair combined with several differently-labelled samples.

        With labels ``[0, 0, 1, 2]`` indices 0 and 1 share a label, while
        indices 2 and 3 each carry a different one. One triplet row is
        expected per remaining index: ``[0, 1, 2]`` and ``[0, 1, 3]``.
        """
        # Arrange
        labels = torch.tensor([0, 0, 1, 2])
        expected_rows = (
            torch.tensor([[0, 1, 2], [0, 1, 3]]).cpu().numpy().tolist()
        )
        # NOTE(review): 0.5 is presumably the loss margin — confirm against
        # the OnlineTripletLoss constructor.
        loss = OnlineTripletLoss(.5)

        # Act
        triplets = loss._generate_all_triplets(labels)

        # Assert: compare as plain nested lists so the failure diff is readable.
        actual_rows = triplets.cpu().numpy().tolist()
        self.assertSequenceEqual(expected_rows, actual_rows)
Esempio n. 2
0
    def test__generate_all_triplets_three_clases_missing_target(self):
        """
        Label values need not be contiguous (classes may be absent in a batch).

        With labels ``[0, 0, 1, 4]`` the values 2 and 3 never occur. The
        expected triplet rows are still ``[0, 1, 2]`` and ``[0, 1, 3]``,
        i.e. they refer to sample positions, not to label values.
        """
        # Arrange
        labels = torch.tensor([0, 0, 1, 4])
        expected_rows = (
            torch.tensor([[0, 1, 2], [0, 1, 3]]).cpu().numpy().tolist()
        )
        # NOTE(review): 0.5 is presumably the loss margin — confirm against
        # the OnlineTripletLoss constructor.
        loss = OnlineTripletLoss(.5)

        # Act
        triplets = loss._generate_all_triplets(labels)

        # Assert: compare as plain nested lists so the failure diff is readable.
        actual_rows = triplets.cpu().numpy().tolist()
        self.assertSequenceEqual(expected_rows, actual_rows)