Example #1
0
def test_batch_cat_and_stack():
    """Exercise ``Batch.cat``/``cat_`` and ``Batch.stack``/``stack_``.

    Covers concatenation and stacking of batches whose keys are shared,
    partially shared (missing entries are expected to be zero-padded),
    or "reserved" (values that are an empty ``Batch()``), plus the
    ``ValueError`` cases for illegal input nesting and mismatched shapes.
    """
    # test cat with compatible keys
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # In-place and out-of-place concatenation must agree on every leaf.
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    # An empty nested Batch in the middle contributes nothing to a.a.
    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

    # test cat with lens infer
    # b has no data under a.a, so its length is inferred from b.b and the
    # corresponding rows of a.a are expected to be zeros.
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    assert np.allclose(ans.a.a,
                       np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    # In-place stack returns None and adds a leading dimension to leaves.
    assert b1.stack_([b2]) is None
    assert isinstance(b1.a.d.e, np.ndarray)
    assert b1.a.d.e.ndim == 2

    # test cat with incompatible keys
    # Keys missing from one batch are expected to be padded with zeros.
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with reserved keys (values are Batch())
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=Batch(),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    # Check the actual cat result; ``ans.a`` is empty by construction, so
    # asserting on it would not test anything.
    assert test.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test stack with compatible keys
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    # Building a Batch from an object array of dicts stacks each key; the
    # missing 'd' in the second dict is expected to be filled with 0.0.
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0

    # test stack with incompatible keys
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])

    # test stack with empty Batch()
    assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
    # d is an empty Batch in `a` only, so it is padded with 0 there;
    # e is empty in every input and stays empty.
    a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
    b = Batch(a=4, b=5, d=6, e=Batch())
    c = Batch(c=7, b=6, d=9, e=Batch())
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    assert d.e.is_empty()
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2], axis=-1)
    assert test.a.is_empty()
    assert test.b.is_empty()
    assert np.allclose(test.common.c,
                       np.stack([b1.common.c, b2.common.c], axis=-1))

    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2])
    ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
                b=torch.stack([torch.zeros(4, 6), b2.b]),
                common=Batch(c=np.stack([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test with illegal input format
    # Nested lists of Batch (instead of a flat sequence) must be rejected.
    with pytest.raises(ValueError):
        Batch.cat([[Batch(a=1)], [Batch(a=1)]])
    with pytest.raises(ValueError):
        Batch.stack([[Batch(a=1)], [Batch(a=1)]])

    # exceptions
    assert Batch.cat([]).is_empty()
    assert Batch.stack([]).is_empty()
    # b1 has an extra scalar key `d` that b2 lacks; combining them is
    # expected to raise.
    b1 = Batch(e=[4, 5], d=6)
    b2 = Batch(e=[4, 6])
    with pytest.raises(ValueError):
        Batch.cat([b1, b2])
    with pytest.raises(ValueError):
        Batch.stack([b1, b2], axis=1)