Ejemplo n.º 1
0
def test_compose():
    """Check ``compose`` on Range/Indices pairs against hand-written results.

    Every scenario builds an outer subset, an inner subset and the expected
    composition, then asserts that ``outer.compose(inner)`` equals the
    expectation.
    """

    def composes_to(outer, inner, expected):
        # The expected value stays on the left-hand side of the comparison.
        assert expected == outer.compose(inner)

    # Range in Range: 5:10 inside 10:20 is expected to give 15:20.
    composes_to(
        Range.from_string('0, 0:N, 10:20'),
        Range.from_string('0, 0:N, 5:10'),
        Range.from_string('0, 0:N, 15:20'),
    )

    # Four dimensions with symbolic extents M and N.
    composes_to(
        Range.from_string('0,0,0:M,0:N'),
        Range.from_string('0,0,0,0:N'),
        Range.from_string('0,0,0,0:N'),
    )

    # The same outer range composed with a Range and then with an Indices.
    outer = Range.from_string('0, 0:N, 0:M, 50:100')
    inner_range = Range.from_string('0, 0, 0, 20:40')
    inner_point = Indices.from_string('0 , 0 , 0 , 0')

    expected_range = Range.from_string('0, 0, 0, 70:90')
    expected_point = Indices.from_string('0, 0, 0, 50')
    composes_to(outer, inner_range, expected_range)
    composes_to(outer, inner_point, expected_point)

    # Purely symbolic entries: the expected result is the Indices 'i,j,k'.
    composes_to(
        Range.from_string('i,j,0:N'),
        Indices.from_string('0,0,k'),
        Indices.from_string('i,j,k'),
    )
Ejemplo n.º 2
0
def test_squeeze_unsqueeze_indices():
    """Round-trip ``squeeze`` / ``unsqueeze`` on several Indices layouts.

    For each layout a deep copy is squeezed while one position is passed via
    ``ignore_indices``. The positions missing from the value returned by
    ``squeeze`` are then handed to ``unsqueeze``. The test asserts that
    ``unsqueeze`` returns those same positions, that they match the
    hand-written expectation, and that the copy compares equal to the
    untouched original afterwards (presumably both calls modify the copy in
    place -- confirm against the Indices implementation).
    """
    # (indices string, position passed in ignore_indices, expected squeezed positions)
    scenarios = (
        ('i, 0', 0, [1]),
        ('0, i', 1, [0]),
        ('i, 0, 0', 0, [1, 2]),
        ('0, i, 0', 1, [0, 2]),
        ('0, 0, i', 2, [0, 1]),
    )

    for text, ignored, expected_squeezed in scenarios:
        original = Indices.from_string(text)
        working = deepcopy(original)

        remaining = working.squeeze(ignore_indices=[ignored])
        # Every position not reported back by squeeze counts as squeezed.
        squeezed = [pos for pos in range(len(original)) if pos not in remaining]
        restored = working.unsqueeze(squeezed)

        assert squeezed == restored
        assert expected_squeezed == squeezed
        assert original == working