Ejemplo n.º 1
0
    def test__get_distance_single(self):
        """
        Each input tensor holds exactly one row, so one distance is expected.

        The expected value 35 equals the sum of the squared element-wise
        differences of the two rows: (1-2)^2 + (8-3)^2 + (7-4)^2.
        """
        # Arrange
        row_x = torch.tensor([[1, 8, 7]], dtype=torch.float)
        row_y = torch.tensor([[2, 3, 4]], dtype=torch.float)
        expected_values = torch.tensor([35]).cpu().numpy().round(2).tolist()
        loss = OnlineTripletLoss(0.5)

        # Act
        distance = loss._get_distance(row_x, row_y)

        # Assert
        # Both sides are rounded to two decimals and compared as plain lists.
        actual_values = distance.cpu().numpy().round(2).tolist()
        self.assertSequenceEqual(expected_values, actual_values)
Ejemplo n.º 2
0
    def test__get_distance_zero(self):
        """
        Each input tensor holds one row and the two rows are identical.

        The distance between a row and itself is expected to be 0.
        """
        # Arrange
        row_x = torch.tensor([[0, 0, 1]], dtype=torch.float)
        row_y = torch.tensor([[0, 0, 1]], dtype=torch.float)
        expected_values = torch.tensor([0]).cpu().numpy().tolist()
        loss = OnlineTripletLoss(0.5)

        # Act
        distance = loss._get_distance(row_x, row_y)

        # Assert
        # No rounding here: the values are compared exactly as plain lists.
        actual_values = distance.cpu().numpy().tolist()
        self.assertSequenceEqual(expected_values, actual_values)