Ejemplo n.º 1
0
def test_easy_balance_balanced():
    """Eight unit weights over four parts: every part must cost exactly 2."""
    num_parts = 4
    unit_weights = [1 for _ in range(8)]

    bounds = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, bounds, num_parts)

    # Report the per-part costs on failure to make the imbalance visible.
    part_costs = get_partition_weights(unit_weights, bounds)
    assert all(cost == 2 for cost in part_costs), part_costs
Ejemplo n.º 2
0
def test_balance_bert():
    """Partition BERT-like per-layer parameter counts into eight parts.

    Only checks that the result passes ``assert_valid_partition``; no specific
    boundaries or costs are asserted.
    """
    # Parameters per layer for a transformer model with 24 transformers and
    # hidden dim 1024: one large leading entry, 24 identical transformer
    # layers, a zero-weight entry and one large trailing entry.
    edge_layer_params = 52559872
    transformer_layer_params = 12596224
    layer_params = (
        [edge_layer_params]
        + [transformer_layer_params] * 24
        + [0, edge_layer_params]
    )
    num_parts = 8

    bounds = partition_balanced(layer_params, num_parts)
    assert_valid_partition(layer_params, bounds, num_parts)
Ejemplo n.º 3
0
def test_int_balanced():
    """Integer weights that split evenly: four parts, each with cost 3."""
    int_weights = [0, 1, 2, 3, 3, 3]
    num_parts = 4
    expected_bounds = [0, 3, 4, 5, 6]

    bounds = partition_balanced(int_weights, num_parts)
    # The exact boundaries are pinned first, then the generic validity check.
    assert bounds == expected_bounds

    assert_valid_partition(int_weights, bounds, num_parts)
    part_costs = get_partition_weights(int_weights, bounds)
    assert all(cost == 3 for cost in part_costs)
Ejemplo n.º 4
0
def test_short_partition():
    """More parts (4) than weights (2): the result must still be valid."""
    num_items, num_parts = 2, 4
    unit_weights = [1 for _ in range(num_items)]

    bounds = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, bounds, num_parts)
Ejemplo n.º 5
0
def test_valid_partition():
    """Ten unit weights placed into a single part must form a valid partition."""
    num_items, num_parts = 10, 1
    unit_weights = [1 for _ in range(num_items)]

    bounds = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, bounds, num_parts)
Ejemplo n.º 6
0
def test_float_midheavy():
    """Float weights with one dominant middle entry, split into three parts."""
    # The heavy entry (30) is deliberately an int among floats.
    mixed_weights = [0.0, 1.1, 30, 3.0]
    num_parts = 3
    expected_bounds = [0, 2, 3, 4]

    bounds = partition_balanced(mixed_weights, num_parts)
    assert_valid_partition(mixed_weights, bounds, num_parts)
    assert bounds == expected_bounds
Ejemplo n.º 7
0
def test_float_lastheavy():
    """Float weights with a dominant final entry, split into two parts."""
    float_weights = [0.0, 1.1, 1.9, 3.0, 30.0]
    num_parts = 2
    expected_bounds = [0, 4, 5]

    bounds = partition_balanced(float_weights, num_parts)
    assert_valid_partition(float_weights, bounds, num_parts)
    # The heavy last weight is expected to sit alone in the second part.
    assert bounds == expected_bounds
Ejemplo n.º 8
0
def test_float_balanced():
    """Float counterpart of the balanced integer case, split into four parts."""
    float_weights = [0.0, 1.1, 1.9, 3.0, 3.0, 3.0]
    num_parts = 4
    expected_bounds = [0, 3, 4, 5, 6]

    bounds = partition_balanced(float_weights, num_parts)
    assert_valid_partition(float_weights, bounds, num_parts)
    assert bounds == expected_bounds