def test_str_dsa_class():
    """str() of a DSA computation is 'dsa.DsaComputation(<variable name>)'."""
    var_a = Variable('a', [0, 1, 2, 3, 4])
    dist_to_two = UnaryFunctionRelation('c1', var_a, lambda v: abs(v - 2))
    node = VariableComputationNode(var_a, [dist_to_two])
    algo = AlgorithmDef.build_with_default_param('dsa')
    dsa = DsaComputation(ComputationDef(node, algo))
    expected = "dsa.DsaComputation(a)"
    assert str(dsa) == expected
def test_findargmin_several_values(self):
    """Two domain values (4 and 5) share the minimal cost of 2."""
    var = Variable("v1", list(range(10)))
    plateau = UnaryFunctionRelation(
        "f1", var, lambda val: 2 if 3 < val < 6 else 10)
    argmins, min_cost = pydcop.dcop.relations.find_arg_optimal(
        var, plateau, mode="min")
    self.assertEqual(len(argmins), 2)
    for expected in (4, 5):
        self.assertIn(expected, argmins)
    self.assertEqual(min_cost, 2)
def test_findargmin_fct(self):
    """|x - 5| over 0..9 has the single minimiser 5, with cost 0."""
    var = Variable("v1", list(range(10)))
    dist_to_five = UnaryFunctionRelation("f1", var, lambda val: abs(val - 5))
    argmins, min_cost = pydcop.dcop.relations.find_arg_optimal(
        var, dist_to_five, mode="min")
    self.assertEqual(len(argmins), 1)
    self.assertEqual(argmins[0], 5)
    self.assertEqual(min_cost, 0)
def test_1_unary_constraint_means_no_neighbors():
    """A variable whose only constraint is unary has no neighbours."""
    var_a = Variable('a', [0, 1, 2, 3, 4])
    unary = UnaryFunctionRelation('c1', var_a, lambda v: abs(v - 2))
    definition = ComputationDef(
        VariableComputationNode(var_a, [unary]),
        AlgorithmDef.build_with_default_param('dsa'))
    dsa = DsaComputation(comp_def=definition)
    neighbours = dsa.neighbors
    assert len(neighbours) == 0
def test_1_unary_constraint(self):
    """With |a - 2| as the only constraint, the best value is [2] at cost 0."""
    var_a = Variable('a', [0, 1, 2, 3, 4])
    unary = UnaryFunctionRelation('c1', var_a, lambda v: abs(v - 2))
    dsa = DsaComputation(var_a, [unary], comp_def=MagicMock())
    best_values, total_cost = dsa._compute_best_value()
    self.assertEqual(best_values, [2])
    self.assertEqual(total_cost, 0)
def test_findargmin_several_values(self):
    """Values 4 and 5 both reach the minimal cost of 2."""
    var = Variable('v1', list(range(10)))
    plateau = UnaryFunctionRelation(
        'f1', var, lambda val: 2 if 3 < val < 6 else 10)
    argmins, min_cost = algorithms.find_arg_optimal(var, plateau, mode='min')
    self.assertEqual(len(argmins), 2)
    for expected in (4, 5):
        self.assertIn(expected, argmins)
    self.assertEqual(min_cost, 2)
def test_findargmin_fct(self):
    """|x - 5| over 0..9 is minimised only at 5, with cost 0."""
    var = Variable('v1', list(range(10)))
    dist_to_five = UnaryFunctionRelation('f1', var, lambda val: abs(val - 5))
    argmins, min_cost = algorithms.find_arg_optimal(
        var, dist_to_five, mode='min')
    self.assertEqual(len(argmins), 1)
    self.assertEqual(argmins[0], 5)
    self.assertEqual(min_cost, 0)
def test_best_unary(self):
    """From x=0, phi's best values are 1 and 4, both at cost 0."""
    x = Variable("x", list(range(5)))
    phi = UnaryFunctionRelation(
        "phi", x, lambda v: 1 if v in (0, 2, 3) else 0)
    mgm2 = Mgm2Computation(x, [phi], comp_def=MagicMock())
    mgm2.__value__ = 0
    best_values, best_cost = mgm2._compute_best_value()
    self.assertEqual(best_cost, 0)
    self.assertEqual(best_values, [1, 4])
def graph_coloring_pb():
    """Solve a two-variable colouring problem and check the expected result.

    x1 takes values in {0, 1} and x2 in {0, 1, 2}. Each variable has a unary
    cost, and an all-diff constraint heavily penalises equal values. The
    expected assignment is x1=1, x2=2.

    :return: 0 if the solver found x1=1 and x2=2, 1 otherwise.
    """
    # Variables and their domains
    x1 = Variable('x1', [0, 1])
    x2 = Variable('x2', [0, 1, 2])

    # Unary cost tables for x1 and x2
    x1_table = {0: 0, 1: -3}
    x2_table = {0: 0, 1: -2, 2: -1}
    x1_cost = UnaryFunctionRelation('x1_cost', x1, lambda v: x1_table[v])
    x2_cost = UnaryFunctionRelation('x2_cost', x2, lambda v: x2_table[v])

    # x1 != x2: equal values get a prohibitive cost of 10000, distinct
    # values cost nothing.
    @relations.AsNAryFunctionRelation(x1, x2)
    def all_diff(x1_val, x2_val):
        return 10000 if x1_val == x2_val else 0

    # Map the factor graph onto agents, then solve it.
    node_agents = distribue_agent_for_all(
        [x1, x2], [x1_cost, x2_cost, all_diff])
    results, _, _ = synchronous_single_run(node_agents)
    print(results)

    if results['x1'] != 1 or results['x2'] != 2:
        logging.info('invalid result found, needs some debugging ...%s',
                     results)
        return 1
    logging.info('SUCCESS !! ')
    return 0
def test_1var_1rel(self):
    """One variable and one relation yield one VARIABLE and one CONSTRAINT node."""
    l1 = Variable('l1', list(range(10)))
    rel_l1 = UnaryFunctionRelation('rel_l1', l1, lambda x: x)
    nodes = as_bipartite_graph([l1], [rel_l1])
    self.assertEqual(len(nodes), 2)
    per_type = {
        kind: sum(1 for n in nodes if n.type == kind)
        for kind in ('VARIABLE', 'CONSTRAINT')
    }
    self.assertEqual(per_type['VARIABLE'], 1)
    self.assertEqual(per_type['CONSTRAINT'], 1)
def test_unary_function_relation(self):
    """_compute_cost for x=0 returns phi(0), which is 1."""
    x = Variable("x", list(range(5)))
    phi = UnaryFunctionRelation(
        "phi", x, lambda v: 1 if v in (0, 2, 3) else 0)
    mgm2 = Mgm2Computation(x, [phi], comp_def=MagicMock())
    mgm2.__value__ = 0
    cost = mgm2._compute_cost({'x': 0})
    self.assertEqual(cost, 1)
def test_best_unary(self):
    """From x=0, phi's best values are 1 and 4, both at cost 0."""
    x = Variable("x", list(range(5)))
    phi = UnaryFunctionRelation(
        "phi", x, lambda v: 1 if v in (0, 2, 3) else 0)
    node = VariableComputationNode(x, [phi])
    algo = AlgorithmDef.build_with_default_param("mgm2")
    mgm2 = Mgm2Computation(ComputationDef(node, algo))
    mgm2.__value__ = 0
    best_values, best_cost = mgm2._compute_best_value()
    self.assertEqual(best_cost, 0)
    self.assertEqual(best_values, [1, 4])
def test_current_local_cost_unary(self):
    """The local cost follows phi: 1 when x=0 and 0 when x=1."""
    x = Variable("x", list(range(5)))
    phi = UnaryFunctionRelation(
        "phi", x, lambda v: 1 if v in (0, 2, 3) else 0)
    at_zero = Mgm2Computation(x, [phi], comp_def=MagicMock())
    at_zero.__value__ = 0
    at_one = Mgm2Computation(x, [phi], comp_def=MagicMock())
    at_one.__value__ = 1
    self.assertEqual(at_zero._current_local_cost(), 1)
    self.assertEqual(at_one._current_local_cost(), 0)
def test_unary_function_relation(self):
    """_compute_cost(x=0) returns phi(0), which is 1."""
    x = Variable("x", list(range(5)))
    phi = UnaryFunctionRelation(
        "phi", x, lambda v: 1 if v in (0, 2, 3) else 0)
    node = VariableComputationNode(x, [phi])
    algo = AlgorithmDef.build_with_default_param("mgm2")
    mgm2 = Mgm2Computation(ComputationDef(node, algo))
    mgm2.__value__ = 0
    self.assertEqual(mgm2._compute_cost(x=0), 1)
def test_1_unary_constraint_means_no_neighbors(self):
    """A computation with only a unary constraint has no neighbours."""
    var_a = Variable('a', [0, 1, 2, 3, 4])
    unary = UnaryFunctionRelation('c1', var_a, lambda v: abs(v - 2))
    dsa = DsaComputation(var_a, [unary], comp_def=MagicMock())
    neighbour_count = len(dsa._neighbors)
    self.assertEqual(neighbour_count, 0)
class TestsConstraintViolation(unittest.TestCase):
    """Checks of ``GdbaComputation._is_violated`` in NZ, NM and MX modes.

    Every scenario builds a computation for x1 with neighbour values fixed to
    x2=1 and x3=2, then probes x1 values 0, 1 and 2 in that order.

    NOTE(review): none of the scenario methods below start with ``test``, so
    the default unittest loader never collects or runs them. Confirm whether
    this is intentional before renaming them.
    """

    domain = list(range(2))
    x1 = Variable('x1', domain)
    x2 = Variable('x2', domain)
    x3 = Variable('x3', domain)
    # NOTE(review): phi is built on a fresh Variable('x1', domain) rather
    # than on the class-level x1; kept unchanged.
    phi = UnaryFunctionRelation('phi', Variable('x1', domain), lambda x: x)
    # Cost 2 if x1 == x2, else 1 if x1 == x3, else 0.
    phi_n_ary = NAryFunctionRelation(
        lambda x1_, x2_, x3_: 2 if x1_ == x2_ else (1 if x1_ == x3_ else 0),
        [x1, x2, x3])

    def _computation_for(self, constraint, mode=None):
        """Return (computation, its first constraint) for x1 with x2=1, x3=2.

        When ``mode`` is given, it is set as the computation's violation mode.
        """
        comp = GdbaComputation(self.x1, [constraint], comp_def=MagicMock())
        comp._neighbors_values['x2'] = 1
        comp._neighbors_values['x3'] = 2
        if mode is not None:
            comp._violation_mode = mode
        return comp, comp.__constraints__[0]

    def _check(self, comp, constraint, expected):
        """Assert _is_violated(constraint, v) == expected[v] for v = 0, 1, 2."""
        for value, violated in enumerate(expected):
            self.assertEqual(comp._is_violated(constraint, value), violated)

    # Default (NZ) mode
    def NZ_violation_unary(self):
        comp, c = self._computation_for(self.phi)
        self._check(comp, c, (False, True, True))

    def NZ_violation_n_ary(self):
        comp, c = self._computation_for(self.phi_n_ary)
        self._check(comp, c, (False, True, True))

    # NM mode
    def NM_violation_unary(self):
        comp, c = self._computation_for(self.phi, 'NM')
        self._check(comp, c, (False, True, True))

    def NM_violation_n_ary(self):
        comp, c = self._computation_for(self.phi_n_ary, 'NM')
        self._check(comp, c, (False, True, True))

    # MX mode
    def MX_violation_unary(self):
        comp, c = self._computation_for(self.phi, 'MX')
        self._check(comp, c, (False, False, True))

    def MX_violation_n_ary(self):
        comp, c = self._computation_for(self.phi_n_ary, 'MX')
        self._check(comp, c, (False, True, False))