Example #1
0
    def test_value_loss_some_marble_left(self):
        """
        Loss when some Marbles arrived at wrong outputs and another Marble
        did not arrive at the target.

        Expected loss: the velocity-weighted distance from the target node
        (self.nodes[0]) to self.marble_3, plus, for each of marble_1 and
        marble_2 paired in order with self.nodes[1:], the negative
        reciprocal of their velocity-weighted distance.
        """
        weights = dict(pos_weight=self.pos_weight, vel_weight=self.vel_weight)
        target_node = self.nodes[0]

        # The closest Marble that did not arrive at the target is penalized
        # by its weighted distance to the target node.
        target_term = velocity_weighted_distance(target_node, self.marble_3,
                                                 **weights)
        # Each wrong output contributes -1 / (weighted distance).
        wrong_terms = [
            -1/velocity_weighted_distance(node, marble, **weights)
            for marble, node in zip([self.marble_1, self.marble_2],
                                    self.nodes[1:])
        ]
        # sum() with a start value adds target_term + w1 + w2, in that order.
        torch.testing.assert_allclose(self.loss, sum(wrong_terms, target_term))
Example #2
0
    def test_find_min_weighted_distance_to(self):
        """
        Base case: the target particle is not part of the Model,
        and pos_weight differs from vel_weight.

        The test expects the minimum weighted distance over the Model's
        Marbles to be the weighted distance to marbles[0].
        """
        pos_weight = 1
        vel_weight = 2

        marbles = [
            gen_marble_at(pos=torch.tensor([0.0, 0.0]),
                          vel=torch.tensor([7.5, 7.5])),
            gen_marble_at(pos=torch.tensor([10.0, 10.0]),
                          vel=torch.tensor([5.0, 5.0])),
        ]
        model = NenwinModel([], marbles)
        target = Marble(torch.tensor([15.0, 15.0]), ZERO, ZERO, 0, None, None)

        closest = velocity_weighted_distance(target, marbles[0],
                                             pos_weight, vel_weight)
        found = find_min_weighted_distance_to(target, model,
                                              pos_weight, vel_weight)
        torch.testing.assert_allclose(found, closest)
Example #3
0
    def test_value_loss_wrong_pred_no_marble_left(self):
        """
        Wrong prediction, and no non-eaten Marble is available.

        Expected loss: the velocity-weighted distance from the target Node
        (self.nodes[self.target_index]) to self.marble, minus the reciprocal
        of the velocity-weighted distance from the wrong Node
        (self.nodes[self.wrong_node_index]) to that same Marble.
        """
        weights = dict(pos_weight=self.pos_weight, vel_weight=self.vel_weight)
        target = self.nodes[self.target_index]
        wrong = self.nodes[self.wrong_node_index]

        attraction = velocity_weighted_distance(target, self.marble, **weights)
        repulsion = 1/velocity_weighted_distance(wrong, self.marble, **weights)
        torch.testing.assert_allclose(self.loss, attraction - repulsion)
Example #4
0
 def test_value_loss_no_output(self):
     """
     No Marble has been eaten by any Node.

     The loss is then expected to be exactly the velocity-weighted
     distance from self.node to self.marble.
     """
     target_distance = velocity_weighted_distance(
         self.node, self.marble,
         pos_weight=self.pos_weight, vel_weight=self.vel_weight)
     torch.testing.assert_allclose(self.loss, target_distance)
Example #5
0
    def test_value_loss_wrong_pred_some_marble_left(self):
        """
        Wrong prediction while another, non-eaten Marble is still
        available at loss-computation time.

        A second Marble (at ZERO) is added to the model and the loss is
        recomputed. Expected loss: the velocity-weighted distance from the
        target Node to that second Marble, minus the reciprocal of the
        velocity-weighted distance from the wrong Node
        (self.nodes[self.wrong_node_index]) to self.marble.
        """
        leftover = gen_marble_at(ZERO, datum="second")
        self.model.add_marbles([leftover])
        self.loss = self.loss_fun(self.target_index)

        weights = dict(pos_weight=self.pos_weight, vel_weight=self.vel_weight)
        to_target = velocity_weighted_distance(
            self.nodes[self.target_index], leftover, **weights)
        to_wrong = velocity_weighted_distance(
            self.nodes[self.wrong_node_index], self.marble, **weights)
        torch.testing.assert_allclose(self.loss, to_target - 1/to_wrong)
Example #6
0
    def test_velocity_weighted_distance_1(self):
        """
        Base case: both weights are 1 and the expected distance is 0.

        The first Marble sits at (15, 15) with zero velocity; the second
        sits at the origin with velocity (15, 15).
        """
        resting = Marble(torch.tensor([15.0, 15.0]), ZERO, ZERO, 0, None, None)
        moving = Marble(torch.tensor([0.0, 0.0]), torch.tensor([15.0, 15.0]),
                        ZERO, 0, None, None)

        zero_distance = torch.tensor(0.0)
        outcome = velocity_weighted_distance(resting, moving, 1, 1)
        torch.testing.assert_allclose(outcome, zero_distance)
Example #7
0
    def test_velocity_weighted_distance_3(self):
        """
        Corner case: vel_weight is zero (pos_weight is 1).

        Even though the second Marble has a non-zero velocity, the result
        is expected to equal the squared plain distance between the Marbles.
        """
        first = Marble(torch.tensor([123.0, 15.0]), ZERO, ZERO, 0, None, None)
        second = Marble(torch.tensor([4.3, 2.4]), torch.tensor([89.0, 30.0]),
                        ZERO, 0, None, None)

        squared_distance = distance(first, second)**2
        pos_weight, vel_weight = 1, 0
        outcome = velocity_weighted_distance(first, second,
                                             pos_weight, vel_weight)
        torch.testing.assert_allclose(outcome, squared_distance)
Example #8
0
    def test_velocity_weighted_distance_3(self):
        """
        Corner case: pos_weight is zero (vel_weight is 1/3).

        The expected value 100.0 equals ((1/3) * 30) ** 2, i.e. the test
        assumes only the second Marble's scaled velocity contributes.
        """
        first = Marble(torch.tensor([0.0, 15.0]), ZERO, ZERO, 0, None, None)
        second = Marble(torch.tensor([0.0, 0.0]), torch.tensor([0.0, 30.0]),
                        ZERO, 0, None, None)

        pos_weight, vel_weight = 0, 1/3
        outcome = velocity_weighted_distance(first, second,
                                             pos_weight, vel_weight)
        torch.testing.assert_allclose(outcome, torch.tensor(100.0))
Example #9
0
    def test_value_loss_no_marble_left(self):
        """
        A correct Marble was eaten by the target Node, so it contributes
        0 to the loss, and no term for the target Node's distance is added.

        Each of the two wrongly arrived Marbles (M_1 and M_2) contributes
        the *negative reciprocal* of its velocity-weighted distance to the
        wrong node (self.nodes[self.wrong_node_index]). The expected loss
        is therefore the sum of those two negative terms.

        Case where no non-eaten Marble is available
        at time of loss computation.
        Not that this should matter, as a correct Marble was eaten!
        """
        wrong_node = self.nodes[self.wrong_node_index]

        # Starts at 0: the correctly eaten Marble adds nothing.
        expected = 0.0
        for marble in [self.marble_1, self.marble_2]:
            expected -= 1/velocity_weighted_distance(wrong_node, marble,
                                                    pos_weight=self.pos_weight,
                                                    vel_weight=self.vel_weight)
        torch.testing.assert_allclose(self.loss, expected)