def test_get_max_marginalized(self):
    """Max-marginalization keeps the best value per remaining configuration.

    Builds a factor over ``self.random_variables`` (the assertions below
    address the second and third variables as ``'y'`` and ``'z'``), maxes out
    the first variable, and checks both the resulting factor values and the
    argmax assignment recorded for each remaining configuration. It then
    maxes out every variable and checks the single global maximum.
    """
    init_factor = Factor(self.random_variables)
    init_factor.add_value([1, 3, 2], 10).add_value([1, 4, 2], 5) \
        .add_value([1, 5, 2], 2)
    init_factor.add_value([2, 4, 2], 20).add_value([2, 5, 2], 30)

    # Max out the first variable only; the result ranges over the other two.
    factor, assignment = init_factor.get_max_marginalized(
        [self.random_variables[0]])
    remaining_variables = [
        self.random_variables[1], self.random_variables[2]
    ]

    self.assertEqual(20, factor.get_value({'y': 4, 'z': 2}))
    self.assertEqual(30, factor.get_value({'y': 5, 'z': 2}))
    self.assertEqual(10, factor.get_value({'y': 3, 'z': 2}))

    # The resulting factor's variable order is not fixed by this test, so
    # rewrite every assignment key into ``remaining_variables`` order before
    # comparing against the expected table.
    positions = [
        factor.random_variables.index(variable)
        for variable in remaining_variables
    ]
    mapped_assignment = {
        tuple(key[position] for position in positions): value
        for key, value in assignment.items()
    }
    # Each remaining configuration maps to the value of the maxed-out
    # variable that achieved the maximum (e.g. (4, 2) -> 2 because 20 > 5).
    self.assertEqual({
        (3, 2): (1, ),
        (4, 2): (2, ),
        (5, 2): (2, )
    }, mapped_assignment)

    # Maxing out every variable leaves an empty-scope factor holding the
    # global maximum, keyed by the empty assignment.
    factor, assignment = init_factor.get_max_marginalized(
        self.random_variables)
    self.assertEqual(30, factor.get_value([]))
    self.assertEqual({(): (2, 5, 2)}, assignment)