def test_value_loss_some_marble_left(self):
    """
    Loss when some non-eaten Marble is still left.

    The expected loss is the velocity-weighted distance from the target
    Node (self.nodes[0]) to self.marble_3, minus the reciprocal of the
    velocity-weighted distance for each wrong (Node, Marble) pair:
    (self.nodes[1], self.marble_1) and (self.nodes[2], self.marble_2).
    """
    weights = {"pos_weight": self.pos_weight, "vel_weight": self.vel_weight}
    target_node = self.nodes[0]

    # Penalty for the closest Marble that did not arrive at the target.
    expected = velocity_weighted_distance(target_node, self.marble_3, **weights)

    # Penalties for wrong outputs: each wrong pair subtracts 1 / distance.
    wrong_pairs = zip(self.nodes[1:], [self.marble_1, self.marble_2])
    for node, marble in wrong_pairs:
        expected -= 1 / velocity_weighted_distance(node, marble, **weights)

    torch.testing.assert_allclose(self.loss, expected)
def test_find_min_weighted_distance_to(self):
    """
    Base case: the target particle is not part of the Model, and the
    position and velocity weights differ.

    The expected minimum is the velocity-weighted distance from the
    target to the first Marble in the Model.
    """
    # (position, velocity) for each Marble placed in the Model.
    specs = [
        (torch.tensor([0.0, 0]), torch.tensor([7.5, 7.5])),
        (torch.tensor([10.0, 10.0]), torch.tensor([5.0, 5.0])),
    ]
    marbles = [gen_marble_at(pos=p, vel=v) for p, v in specs]
    model = NenwinModel([], marbles)
    target = Marble(torch.tensor([15.0, 15.0]), ZERO, ZERO, 0, None, None)

    pos_weight, vel_weight = 1, 2
    expected = velocity_weighted_distance(target, marbles[0],
                                          pos_weight, vel_weight)
    result = find_min_weighted_distance_to(target, model,
                                           pos_weight, vel_weight)
    torch.testing.assert_allclose(result, expected)
def test_value_loss_wrong_pred_no_marble_left(self):
    """
    Wrong prediction, case where no non-eaten Marble is available.

    The expected loss is the velocity-weighted distance from the target
    Node (self.nodes[self.target_index]) to self.marble, plus the
    *negative reciprocal* of the velocity-weighted distance from the wrong
    Node (self.nodes[self.wrong_node_index]) to self.marble.
    """
    def weighted_dist(node):
        return velocity_weighted_distance(node, self.marble,
                                          pos_weight=self.pos_weight,
                                          vel_weight=self.vel_weight)

    target_distance = weighted_dist(self.nodes[self.target_index])
    wrong_penalty = 1 / weighted_dist(self.nodes[self.wrong_node_index])
    torch.testing.assert_allclose(self.loss, target_distance - wrong_penalty)
def test_value_loss_no_output(self):
    """
    When no Marble has been eaten, the loss should be exactly the
    velocity-weighted distance between the target Node (self.node) and
    self.marble.
    """
    weights = dict(pos_weight=self.pos_weight, vel_weight=self.vel_weight)
    torch.testing.assert_allclose(
        self.loss,
        velocity_weighted_distance(self.node, self.marble, **weights),
    )
def test_value_loss_wrong_pred_some_marble_left(self):
    """
    Wrong prediction, case where another non-eaten Marble is still
    available at the time of loss computation.

    A second Marble is added at ZERO and the loss is recomputed. The
    expected loss is the velocity-weighted distance from the target Node
    to that second Marble, plus the *negative reciprocal* of the
    velocity-weighted distance from the wrong Node to self.marble.
    """
    second_marble = gen_marble_at(ZERO, datum="second")
    self.model.add_marbles([second_marble])
    # Recompute the loss now that the extra Marble is in the Model.
    self.loss = self.loss_fun(self.target_index)

    weights = {"pos_weight": self.pos_weight, "vel_weight": self.vel_weight}
    target_node = self.nodes[self.target_index]
    wrong_node = self.nodes[self.wrong_node_index]

    target_distance = velocity_weighted_distance(target_node, second_marble,
                                                 **weights)
    wrong_penalty = 1 / velocity_weighted_distance(wrong_node, self.marble,
                                                   **weights)
    torch.testing.assert_allclose(self.loss, target_distance - wrong_penalty)
def test_velocity_weighted_distance_1(self):
    """
    Base case: zero distance, with both weights equal to 1.

    The first Marble is at (15, 15) with velocity ZERO; the second starts
    at the origin with velocity (15, 15).
    """
    first = Marble(torch.tensor([15.0, 15.0]), ZERO, ZERO, 0, None, None)
    second = Marble(torch.tensor([0.0, 0.0]), torch.tensor([15.0, 15.0]),
                    ZERO, 0, None, None)

    result = velocity_weighted_distance(first, second, 1, 1)
    torch.testing.assert_allclose(result, torch.tensor(0.0))
def test_velocity_weighted_distance_vel_weight_zero(self):
    """
    Corner case: vel weight zero.

    With pos_weight == 1 and vel_weight == 0, the velocity-weighted
    distance should equal the squared plain distance between the Marbles.
    """
    # This method used to share the name `test_velocity_weighted_distance_3`
    # with the pos-weight-zero test. The later definition replaced it in the
    # class namespace, so this case never ran. A unique name restores it.
    pos_1 = torch.tensor([123.0, 15.0])
    m1 = Marble(pos_1, ZERO, ZERO, 0, None, None)
    pos_2 = torch.tensor([4.3, 2.4])
    vel_2 = torch.tensor([89.0, 30.0])
    m2 = Marble(pos_2, vel_2, ZERO, 0, None, None)

    expected = distance(m1, m2)**2
    pos_weight = 1
    vel_weight = 0
    result = velocity_weighted_distance(m1, m2, pos_weight, vel_weight)
    torch.testing.assert_allclose(result, expected)
def test_velocity_weighted_distance_3(self):
    """
    Corner case: pos weight zero.

    With pos_weight == 0 and vel_weight == 1/3, the expected result is 100.
    """
    # NOTE(review): 100 presumably equals (1/3 * 30)**2, i.e. only the
    # velocity term contributes — confirm against velocity_weighted_distance.
    first = Marble(torch.tensor([0.0, 15.0]), ZERO, ZERO, 0, None, None)
    second = Marble(torch.tensor([0.0, 0.0]), torch.tensor([0.0, 30.0]),
                    ZERO, 0, None, None)

    result = velocity_weighted_distance(first, second, 0, 1/3)
    torch.testing.assert_allclose(result, torch.tensor(100.0))
def test_value_loss_no_marble_left(self):
    """
    Case where no non-eaten Marble is available at the time of loss
    computation. Not that this should matter, since a correct Marble
    was eaten.

    The correctly arrived Marble contributes nothing, so no target-distance
    term appears. Each of the two wrongly arrived Marbles (self.marble_1
    and self.marble_2) contributes the *negative reciprocal* of its
    velocity-weighted distance to the wrong Node
    (self.nodes[self.wrong_node_index]).
    """
    wrong_node = self.nodes[self.wrong_node_index]
    weights = {"pos_weight": self.pos_weight, "vel_weight": self.vel_weight}

    expected = -sum(
        1 / velocity_weighted_distance(wrong_node, marble, **weights)
        for marble in (self.marble_1, self.marble_2)
    )
    torch.testing.assert_allclose(self.loss, expected)