def test_split(weight, split_size_or_sections, number_of_splits):
    """Check the piece lengths produced by ``Split`` with ``keep_size=False``.

    The operator is first applied to ``weight``, where every returned piece
    must have length 5.  After ``set_input_indices`` has been called it is
    applied to a 12-element random tensor, where every piece must have
    length 4.
    """
    splitter = Split(
        True,
        split_size_or_sections,
        number_of_splits,
        keep_size=False,
    )

    # Before any input indices are registered.
    for piece in splitter(weight):
        assert len(piece) == 5

    # After registering input indices the operator is fed a new input.
    splitter.set_input_indices((weight, ))
    for piece in splitter(torch.rand(12)):
        assert len(piece) == 4
def test_split_2(weight, split_size_or_sections, number_of_splits):
    """Check ``Split`` with ``keep_size=True``.

    Every piece must have length 5 both before and after
    ``set_input_indices`` is called.  After the call, each piece produced
    from a 12-element random tensor is expected to contain exactly one zero,
    at the position listed in ``location_of_zeros_in_splits``.
    """
    op = Split(True, split_size_or_sections, number_of_splits, keep_size=True)
    s = op(weight)
    assert all(len(x) == 5 for x in s)

    op.set_input_indices((weight, ))
    s = op(torch.rand(12))
    assert all(len(x) == 5 for x in s)

    # keep_size=True expands the input with zeros
    location_of_zeros_in_splits = [4, 2, 2]
    for x, i in zip(s, location_of_zeros_in_splits):
        (idx, ) = torch.where(x == 0)
        # Compare as plain Python lists: asserting on ``idx == tensor`` raises
        # RuntimeError (ambiguous truth value) instead of failing cleanly when
        # the piece holds no zero or more than one zero.
        assert idx.tolist() == [i]
def test_split_parameter_check(weight):
    """Check that ``Split`` rejects pruning with no split specification.

    Constructing the operator with ``enable_pruning=True`` while both
    ``split_size_or_sections`` and ``number_of_splits`` are ``None`` must
    raise ``AssertionError``.
    """
    invalid_kwargs = {
        "enable_pruning": True,
        "split_size_or_sections": None,
        "number_of_splits": None,
    }
    with pytest.raises(AssertionError):
        Split(**invalid_kwargs)
def test_split_no_enable_pruning(weight, split_size_or_sections):
    """Check ``Split`` with pruning disabled.

    With ``enable_pruning=False`` the split specification is passed at call
    time; every piece returned for ``weight`` must have length 5.
    """
    splitter = Split(enable_pruning=False, keep_size=False)
    pieces = splitter(weight, split_size_or_sections)
    assert not any(len(piece) != 5 for piece in pieces)