Ejemplo n.º 1
0
def test_e2e():
    """Compute a matching and run one simulation step, end to end.

    A pool with two lhs elements and one rhs element is matched with
    ``compute_matching`` using all-zero potentials; the resulting edges and
    weights are passed to ``step_simulation`` together with hand-written
    ``desired_dict_*`` mappings. The test then checks the returned loss and
    the ``lhs``/``rhs`` contents of the returned pool.
    """
    # NOTE(review): `hist` is never read below -- the desired dictionaries are
    # spelled out by hand instead. Kept so the History constructor still runs.
    hist = History([[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5],
                    [torch.tensor(2), 2, 5]],
                   [[torch.tensor(1), 0, 1], [torch.tensor(1), 1, 5]])

    # Each dictionary groups the elements above under the value of their
    # second field (0 and 2 on the lhs, 0 and 1 on the rhs).
    desired_dict_lhs = defaultdict(list)
    desired_dict_lhs[0] = [[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]]
    desired_dict_lhs[2] = [[torch.tensor(2), 2, 5]]

    desired_dict_rhs = defaultdict(list)
    desired_dict_rhs[0] = [[torch.tensor(1), 0, 1]]
    desired_dict_rhs[1] = [[torch.tensor(1), 1, 5]]

    currpool = CurrentElems([[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]],
                            [[torch.tensor(1), 0, 1]])

    e_weights_type = toy_e_weights_type()
    potentials = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0])
    match_edges, e_weights_full = compute_matching(currpool, potentials,
                                                   e_weights_type)

    # NOTE(review): the trailing `1` is presumably the current time step --
    # confirm against step_simulation's signature.
    result_pool, total_loss = step_simulation(currpool, match_edges,
                                              e_weights_full, desired_dict_lhs,
                                              desired_dict_rhs, 1)

    # `approx(actual, 0.1)` only builds an approx object (with rel=0.1) and
    # never compares anything; the comparison has to be spelled out with `==`.
    # Assumes total_loss is a scalar (Python number or single-element tensor).
    assert float(total_loss) == approx(0.1, abs=1e-6)

    assert result_pool.lhs == [[torch.tensor(2), 0, 5]]
    assert result_pool.rhs == [[torch.tensor(1), 1, 5]]
Ejemplo n.º 2
0
def test_step_simulation():
    """Run ``step_simulation`` on a fixed matching and check loss and pool.

    The matching is supplied by hand (second lhs row set to 1.0 against the
    single rhs element) and the edge weights come from ``weight_matrix``. The
    test checks the returned loss and the ``lhs``/``rhs`` contents of the
    returned pool.
    """
    # NOTE(review): `hist` is never read below -- the desired dictionaries are
    # spelled out by hand instead. Kept so the History constructor still runs.
    hist = History([[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5],
                    [torch.tensor(2), 2, 5]],
                   [[torch.tensor(1), 0, 1], [torch.tensor(1), 1, 5]])

    # Each dictionary groups the elements above under the value of their
    # second field (0 and 2 on the lhs, 0 and 1 on the rhs).
    desired_dict_lhs = defaultdict(list)
    desired_dict_lhs[0] = [[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]]
    desired_dict_lhs[2] = [[torch.tensor(2), 2, 5]]

    desired_dict_rhs = defaultdict(list)
    desired_dict_rhs[0] = [[torch.tensor(1), 0, 1]]
    desired_dict_rhs[1] = [[torch.tensor(1), 1, 5]]

    currpool = CurrentElems([[torch.tensor(1), 0, 2], [torch.tensor(2), 0, 5]],
                            [[torch.tensor(1), 0, 1]])
    # One row per lhs element, one column per rhs element.
    match_edges = torch.tensor([[0.0], [1.0]])
    e_weights_type = toy_e_weights_type()
    e_weights = weight_matrix(currpool.lhs, currpool.rhs, e_weights_type)

    # NOTE(review): the trailing `1` is presumably the current time step --
    # confirm against step_simulation's signature.
    result_pool, total_loss = step_simulation(currpool, match_edges, e_weights,
                                              desired_dict_lhs,
                                              desired_dict_rhs, 1)

    # `approx(actual, 0.1)` only builds an approx object (with rel=0.1) and
    # never compares anything; the comparison has to be spelled out with `==`.
    # Assumes total_loss is a scalar (Python number or single-element tensor).
    assert float(total_loss) == approx(0.1, abs=1e-6)

    assert result_pool.lhs == [[torch.tensor(1), 0, 2]]
    assert result_pool.rhs == [[torch.tensor(1), 1, 5]]
Ejemplo n.º 3
0
def unambiguous_matching():
    """Build a three-by-one pool together with the matching expected for it.

    Returns:
        A ``(currpool, e_weights, correct_matching)`` tuple: the pool, the
        object returned by ``toy_e_weights_type()``, and a 3x1 tensor whose
        only non-zero entry is in the second row.
    """
    lhs_elems = [
        [torch.tensor(2), 0, 5],
        [torch.tensor(1), 0, 5],
        [torch.tensor(2), 0, 5],
    ]
    rhs_elems = [[torch.tensor(0), 0, 5]]
    pool = CurrentElems(lhs_elems, rhs_elems)
    weights_type = toy_e_weights_type()
    # Only the middle lhs row is marked as matched to the single rhs column.
    expected = torch.tensor([[0.0], [1.0], [0.0]])
    return pool, weights_type, expected
Ejemplo n.º 4
0
def test_weight_matrix():
    """Check ``weight_matrix`` on a three-by-one pool against fixed values.

    The expected column is -100.0 for the two lhs rows whose first field is
    ``tensor(2)`` and 3.0 for the row whose first field is ``tensor(1)``.
    """
    lhs_elems = [
        [torch.tensor(2), 0, 5],
        [torch.tensor(1), 0, 5],
        [torch.tensor(2), 0, 5],
    ]
    rhs_elems = [[torch.tensor(0), 0, 5]]
    pool = CurrentElems(lhs_elems, rhs_elems)
    weights_type = toy_e_weights_type()

    actual = weight_matrix(pool.lhs, pool.rhs, weights_type)

    expected = torch.tensor([[-100.0], [3.0], [-100.0]])
    assert torch.allclose(actual, expected)
Ejemplo n.º 5
0
def test_zero_potentials():
    """Check ``compute_matching`` on a two-by-one pool with zero potentials.

    The expected matching matrix has 1.0 in the first lhs row and 0.0 in the
    second, compared with an absolute tolerance of 1e-6.
    """
    lhs_elems = [[torch.tensor(1), 0, 5], [torch.tensor(2), 0, 5]]
    rhs_elems = [[torch.tensor(2), 0, 5]]
    pool = CurrentElems(lhs_elems, rhs_elems)
    weights_type = toy_e_weights_type()
    zero_potentials = torch.tensor([0.0] * 5)

    expected = torch.tensor([[1.0], [0.0]])

    # The second return value (the weights) is not needed by this test.
    actual, _ = compute_matching(pool, zero_potentials, weights_type)
    assert torch.allclose(actual, expected, atol=1e-6)
Ejemplo n.º 6
0
def dont_test_tiebreak():
    """Disabled tie-break check on a pool of identical elements.

    The name lacks the ``test_`` prefix, so pytest's default discovery does
    not collect it. All three lhs and both rhs elements are identical; the
    expected matrix pairs the first two lhs rows with the two rhs columns.
    """
    # this is a failing test that reveals a fractional matching
    identical = lambda: [torch.tensor(1), 0, 5]
    pool = CurrentElems([identical(), identical(), identical()],
                        [identical(), identical()])

    weights_type = toy_e_weights_type()
    # Computed as in the original but not used by the assertion below.
    full_weights = weight_matrix(pool.lhs, pool.rhs, weights_type)

    expected = torch.tensor([
        [1.0, 0.0],
        [0.0, 1.0],
        [0.0, 0.0],
    ])

    actual, _ = compute_matching(pool, torch.zeros(5), weights_type)
    assert torch.allclose(actual, expected, atol=1e-6)