def test_easy_balance_balanced():
    """Split eight unit weights into four parts and expect every part to weigh 2."""
    num_parts = 4
    unit_weights = [1] * 8

    boundaries = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, boundaries, num_parts)

    part_costs = get_partition_weights(unit_weights, boundaries)
    # Show the full cost list on failure so any imbalance is visible.
    assert all(cost == 2 for cost in part_costs), part_costs
def test_balance_bert():
    """Partition a transformer-sized per-layer parameter profile into eight parts.

    Only the validity of the resulting partition is checked.
    """
    # Parameters per layer for a transformer model with 24 transformers and hidden dim 1024
    leading_layer = [52559872]
    transformer_layers = [12596224] * 24
    trailing_layers = [0, 52559872]
    layer_params = leading_layer + transformer_layers + trailing_layers

    num_parts = 8
    boundaries = partition_balanced(layer_params, num_parts)
    assert_valid_partition(layer_params, boundaries, num_parts)
def test_int_balanced():
    """Partition integer weights four ways and expect each part to weigh 3."""
    num_parts = 4
    int_weights = [0, 1, 2, 3, 3, 3]
    expected_parts = [0, 3, 4, 5, 6]

    boundaries = partition_balanced(int_weights, num_parts)
    assert boundaries == expected_parts
    assert_valid_partition(int_weights, boundaries, num_parts)

    part_costs = get_partition_weights(int_weights, boundaries)
    assert all(cost == 3 for cost in part_costs)
def test_short_partition():
    """Request more parts (4) than there are weights (2); the result must still validate."""
    num_parts = 4
    num_items = 2
    unit_weights = [1] * num_items

    boundaries = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, boundaries, num_parts)
def test_valid_partition():
    """Partition ten unit weights into a single part and check the result validates."""
    num_parts = 1
    num_items = 10
    unit_weights = [1] * num_items

    boundaries = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, boundaries, num_parts)
def test_float_midheavy():
    """Partition weights dominated by one heavy middle entry into three parts."""
    num_parts = 3
    # NOTE(review): 30 is an int literal among floats; kept as in the original — confirm intent.
    mixed_weights = [0., 1.1, 30, 3.]
    expected_parts = [0, 2, 3, 4]

    boundaries = partition_balanced(mixed_weights, num_parts)
    assert_valid_partition(mixed_weights, boundaries, num_parts)
    assert boundaries == expected_parts
def test_float_lastheavy():
    """Partition float weights whose last entry is by far the heaviest into two parts."""
    num_parts = 2
    float_weights = [0., 1.1, 1.9, 3., 30.]
    expected_parts = [0, 4, 5]

    boundaries = partition_balanced(float_weights, num_parts)
    assert_valid_partition(float_weights, boundaries, num_parts)
    assert boundaries == expected_parts
def test_float_balanced():
    """Partition float weights four ways and expect the parts [0, 3, 4, 5, 6]."""
    num_parts = 4
    float_weights = [0., 1.1, 1.9, 3., 3., 3.]
    expected_parts = [0, 3, 4, 5, 6]

    boundaries = partition_balanced(float_weights, num_parts)
    assert_valid_partition(float_weights, boundaries, num_parts)
    assert boundaries == expected_parts