def test_e2e():
    """End-to-end: compute a matching from zero potentials, then step once.

    Starting from a pool with two LHS elements and one RHS element, this runs
    ``compute_matching`` with all-zero potentials, feeds the resulting matching
    and edge weights into ``step_simulation`` for time 1, and checks the
    returned loss and the pool contents afterwards.
    """
    # Arrival schedules keyed by time; the key equals each element's second
    # field. NOTE(review): the element layout [type, arrival, departure] is
    # inferred from these values — confirm against CurrentElems.
    desired_dict_lhs = defaultdict(list)
    desired_dict_lhs[0] = [[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]]
    desired_dict_lhs[2] = [[torch.tensor(2), 2, 5]]
    desired_dict_rhs = defaultdict(list)
    desired_dict_rhs[0] = [[torch.tensor(1), 0, 1]]
    desired_dict_rhs[1] = [[torch.tensor(1), 1, 5]]

    # Pool as it stands at time 0 (the time-0 arrivals above).
    currpool = CurrentElems(
        [[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]],
        [[torch.tensor(1), 0, 1]],
    )
    e_weights_type = toy_e_weights_type()
    potentials = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0])

    match_edges, e_weights_full = compute_matching(currpool, potentials, e_weights_type)
    result_pool, total_loss = step_simulation(
        currpool, match_edges, e_weights_full, desired_dict_lhs, desired_dict_rhs, 1
    )

    # The previous ``assert approx(total_loss, 0.1)`` only constructed an
    # approx object (with rel=0.1), which is always truthy, so the loss was
    # never actually compared. NOTE(review): assumes ``approx`` is
    # ``pytest.approx``. float() accepts both a Python number and a 0-d tensor.
    assert float(total_loss) == approx(0.1, abs=1e-6)
    assert result_pool.lhs == [[torch.tensor(2), 0, 5]]
    assert result_pool.rhs == [[torch.tensor(1), 1, 5]]
def test_step_simulation():
    """Step the simulation once with a hand-specified matching.

    The matching pairs the second LHS element (type 2) with the only RHS
    element. After stepping to time 1, only the first LHS element should
    remain on the left, and the time-1 RHS arrival should be on the right.
    """
    # Arrival schedules keyed by time; the key equals each element's second
    # field. NOTE(review): the element layout [type, arrival, departure] is
    # inferred from these values — confirm against CurrentElems.
    desired_dict_lhs = defaultdict(list)
    desired_dict_lhs[0] = [[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]]
    desired_dict_lhs[2] = [[torch.tensor(2), 2, 5]]
    desired_dict_rhs = defaultdict(list)
    desired_dict_rhs[0] = [[torch.tensor(1), 0, 1]]
    desired_dict_rhs[1] = [[torch.tensor(1), 1, 5]]

    currpool = CurrentElems(
        [[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]],
        [[torch.tensor(1), 0, 1]],
    )
    # One row per LHS element, one column per RHS element: match LHS[1]-RHS[0].
    match_edges = torch.tensor([[0.0], [1.0]])
    e_weights_type = toy_e_weights_type()
    e_weights = weight_matrix(currpool.lhs, currpool.rhs, e_weights_type)

    result_pool, total_loss = step_simulation(
        currpool, match_edges, e_weights, desired_dict_lhs, desired_dict_rhs, 1
    )

    # The previous ``assert approx(total_loss, 0.1)`` only constructed an
    # approx object (with rel=0.1), which is always truthy, so the loss was
    # never actually compared. NOTE(review): assumes ``approx`` is
    # ``pytest.approx``. float() accepts both a Python number and a 0-d tensor.
    assert float(total_loss) == approx(0.1, abs=1e-6)
    assert result_pool.lhs == [[torch.tensor(1), 0, 2]]
    assert result_pool.rhs == [[torch.tensor(1), 1, 5]]
def unambiguous_matching():
    """Build a small pool whose correct matching uses a single edge.

    Returns:
        A ``(currpool, e_weights, correct_matching)`` tuple:
        ``currpool`` holds three LHS elements (types 2, 1, 2) and one RHS
        element (type 0); ``e_weights`` is the toy edge-weight type; and
        ``correct_matching`` is a 3x1 tensor in which only the second LHS
        element (the type-1 one) is matched.
    """
    # NOTE(review): no ``test_`` prefix and no fixture decorator is visible
    # here, so pytest does not collect this; presumably a helper for tests.
    pool = CurrentElems(
        [
            [torch.tensor(2), 0, 5],
            [torch.tensor(1), 0, 5],
            [torch.tensor(2), 0, 5],
        ],
        [[torch.tensor(0), 0, 5]],
    )
    weights_type = toy_e_weights_type()
    expected = torch.tensor([[0.0], [1.0], [0.0]])
    return pool, weights_type, expected
def test_weight_matrix():
    """Check the toy edge weights between three LHS elements and one RHS element.

    With three LHS elements (types 2, 1, 2) and one RHS element (type 0), the
    resulting 3x1 weight tensor must be 3 for the type-1 element and -100 for
    both type-2 elements.
    """
    lhs_elems = [
        [torch.tensor(2), 0, 5],
        [torch.tensor(1), 0, 5],
        [torch.tensor(2), 0, 5],
    ]
    rhs_elems = [[torch.tensor(0), 0, 5]]
    pool = CurrentElems(lhs_elems, rhs_elems)
    weights_type = toy_e_weights_type()

    actual = weight_matrix(pool.lhs, pool.rhs, weights_type)

    expected = torch.tensor([[-100.0], [3.0], [-100.0]])
    assert torch.allclose(actual, expected)
def test_zero_potentials():
    """With all-zero potentials, the type-1 LHS element gets the single RHS element.

    The pool has LHS elements of types 1 and 2 and one RHS element of type 2.
    ``compute_matching`` must return the 2x1 matching [[1], [0]] to within
    1e-6.
    """
    pool = CurrentElems(
        [[torch.tensor(1), 0, 5], [torch.tensor(2), 0, 5]],
        [[torch.tensor(2), 0, 5]],
    )
    weights_type = toy_e_weights_type()
    zero_potentials = torch.tensor([0.0] * 5)
    expected = torch.tensor([[1.0], [0.0]])

    # The second return value (full edge weights) is not needed here.
    matching, _ = compute_matching(pool, zero_potentials, weights_type)

    assert torch.allclose(matching, expected, atol=1e-6)
def test_arrivals_only():
    """Check that ``arrivals_only`` at time 0 fills an empty pool with the time-0 arrivals.

    Only the entries stored under key 0 of each arrival schedule should
    appear. The later arrivals (LHS at time 2, RHS at time 1) must not.
    """
    # Arrival schedules keyed by time; the key equals each element's second
    # field. (The unused ``hist = History(...)`` local was removed.)
    desired_dict_lhs = defaultdict(list)
    desired_dict_lhs[0] = [[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]]
    desired_dict_lhs[2] = [[torch.tensor(2), 2, 5]]
    desired_dict_rhs = defaultdict(list)
    desired_dict_rhs[0] = [[torch.tensor(1), 0, 1]]
    desired_dict_rhs[1] = [[torch.tensor(1), 1, 5]]

    currpool = CurrentElems([], [])
    newpool = arrivals_only(currpool, desired_dict_lhs, desired_dict_rhs, 0)

    # Built through CurrentElems so the comparison sees the same
    # representation the pool class produces.
    target_pool = CurrentElems(
        [[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]],
        [[torch.tensor(1), 0, 1]],
    )
    assert newpool.lhs == target_pool.lhs
    assert newpool.rhs == target_pool.rhs
def dont_test_tiebreak():
    """Disabled test: tie-breaking between identical elements.

    This is a failing test that reveals a fractional matching. It uses three
    identical LHS elements and two identical RHS elements (all type 1), and
    expects ``compute_matching`` with zero potentials to return the integral
    3x2 matching [[1, 0], [0, 1], [0, 0]]. The ``dont_`` prefix presumably
    exists so pytest does not collect it.
    """
    # Comprehensions create distinct lists and tensors for each element, as
    # the original literal did.
    pool = CurrentElems(
        [[torch.tensor(1), 0, 5] for _ in range(3)],
        [[torch.tensor(1), 0, 5] for _ in range(2)],
    )
    weights_type = toy_e_weights_type()
    # NOTE(review): the full weight matrix is computed but never used below.
    _full_weights = weight_matrix(pool.lhs, pool.rhs, weights_type)
    expected = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])

    matching, _ = compute_matching(pool, torch.zeros(5), weights_type)

    assert torch.allclose(matching, expected, atol=1e-6)