def test_batch(): assert list(Batch()) == [] assert Batch().is_empty() assert not Batch(b={'c': {}}).is_empty() assert Batch(b={'c': {}}).is_empty(recurse=True) assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty() assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True) assert not Batch(d=1).is_empty() assert not Batch(a=np.float64(1.0)).is_empty() assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 assert not Batch(a=[1, 2, 3]).is_empty() b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None]) assert b.c.dtype == object b = Batch(d=[None], e=[starmap], f=Batch) assert b.d.dtype == b.e.dtype == object and b.f == Batch b = Batch() b.update() assert b.is_empty() b.update(c=[3, 5]) assert np.allclose(b.c, [3, 5]) # mimic the behavior of dict.update, where kwargs can overwrite keys b.update({'a': 2}, a=3) assert 'a' in b and b.a == 3 assert b.pop('a') == 3 assert 'a' not in b with pytest.raises(AssertionError): Batch({1: 2}) assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object with pytest.raises(TypeError): Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch.cat_(batch) assert torch.allclose(batch.a, torch.ones(4, 3)) Batch(a=[]) batch = Batch(obs=[0], np=np.zeros([3, 4])) assert batch.obs == batch["obs"] batch.obs = [1] assert batch.obs == [1] batch.cat_(batch) assert np.allclose(batch.obs, [1, 1]) assert batch.np.shape == (6, 4) assert np.allclose(batch[0].obs, batch[1].obs) batch.obs = np.arange(5) for i, b in enumerate(batch.split(1, shuffle=False)): if i != 5: assert b.obs == batch[i].obs else: with pytest.raises(AttributeError): batch[i].obs with pytest.raises(AttributeError): b.obs print(batch) batch = Batch(a=np.arange(10)) with pytest.raises(AssertionError): list(batch.split(0)) data = [ (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]), (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]), (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]), (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), ] for size, merge_last, result in data: bs = list(batch.split(size, shuffle=False, merge_last=merge_last)) assert [bs[i].a.tolist() for i in range(len(bs))] == result batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])} batch_item = Batch({'a': [batch_dict]})[0] assert isinstance(batch_item.a.b, np.ndarray) assert batch_item.a.b == batch_dict['b'] assert isinstance(batch_item.a.c, float) assert batch_item.a.c == batch_dict['c'] assert isinstance(batch_item.a.d, torch.Tensor) assert batch_item.a.d == batch_dict['d'] batch2 = Batch(a=[{ 'b': np.float64(1.0), 'c': np.zeros(1), 'd': Batch(e=np.array(3.0))}]) assert len(batch2) == 1 assert Batch().shape == [] assert Batch(a=1).shape == [] assert Batch(a=set((1, 2, 1))).shape == [] assert batch2.shape[0] == 1 assert 'a' in batch2 and all([i in batch2.a for i in 'bcd']) with pytest.raises(IndexError): batch2[-2] with pytest.raises(IndexError): batch2[1] assert batch2[0].shape == [] with pytest.raises(IndexError): batch2[0][0] with pytest.raises(TypeError): len(batch2[0]) assert isinstance(batch2[0].a.c, np.ndarray) assert isinstance(batch2[0].a.b, np.float64) assert isinstance(batch2[0].a.d.e, np.float64) batch2_from_list = Batch(list(batch2)) batch2_from_comp = Batch([e for e in batch2]) assert batch2_from_list.a.b == batch2.a.b assert batch2_from_list.a.c == batch2.a.c assert batch2_from_list.a.d.e == batch2.a.d.e assert batch2_from_comp.a.b == batch2.a.b assert batch2_from_comp.a.c == batch2.a.c assert batch2_from_comp.a.d.e == batch2.a.d.e for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]: assert batch_slice.a.b == batch2.a.b assert batch_slice.a.c == batch2.a.c assert batch_slice.a.d.e == batch2.a.d.e batch2.a.d.f = {} batch2_sum = (batch2 + 1.0) * 2 assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 assert batch2_sum.a.d.f.is_empty() with pytest.raises(TypeError): batch2 += [1] batch3 = Batch(a={ 'c': np.zeros(1), 'd': Batch(e=np.array([0.0]), f=np.array([3.0]))}) batch3.a.d[0] = {'e': 4.0} assert batch3.a.d.e[0] == 4.0 batch3.a.d[0] = Batch(f=5.0) assert batch3.a.d.f[0] == 5.0 with pytest.raises(ValueError): batch3.a.d[0] = Batch(f=5.0, g=0.0) with pytest.raises(ValueError): batch3[0] = Batch(a={"c": 2, "e": 1}) # auto convert batch4 = Batch(a=np.array(['a', 'b'])) assert batch4.a.dtype == object # auto convert to object batch4.update(a=np.array(['c', 'd'])) assert list(batch4.a) == ['c', 'd'] assert batch4.a.dtype == object # auto convert to object batch5 = Batch(a=np.array([{'index': 0}])) assert isinstance(batch5.a, Batch) assert np.allclose(batch5.a.index, [0]) batch5.b = np.array([{'index': 1}]) assert isinstance(batch5.b, Batch) assert np.allclose(batch5.b.index, [1]) # None is a valid object and can be stored in Batch a = Batch.stack([Batch(a=None), Batch(b=None)]) assert a.a[0] is None and a.a[1] is None assert a.b[0] is None and a.b[1] is None # nx.Graph corner case assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object g1 = nx.Graph() g1.add_nodes_from(list(range(10))) g2 = nx.Graph() g2.add_nodes_from(list(range(20))) assert Batch(a=np.array([g1, g2])).a.dtype == object
def test_batch(): assert list(Batch()) == [] assert Batch().is_empty() assert not Batch(b={'c': {}}).is_empty() assert Batch(b={'c': {}}).is_empty(recurse=True) assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty() assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True) assert not Batch(d=1).is_empty() assert not Batch(a=np.float64(1.0)).is_empty() assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 assert not Batch(a=[1, 2, 3]).is_empty() b = Batch() b.update() assert b.is_empty() b.update(c=[3, 5]) assert np.allclose(b.c, [3, 5]) # mimic the behavior of dict.update, where kwargs can overwrite keys b.update({'a': 2}, a=3) assert b.a == 3 with pytest.raises(AssertionError): Batch({1: 2}) with pytest.raises(TypeError): Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) Batch(a=[]) batch = Batch(obs=[0], np=np.zeros([3, 4])) assert batch.obs == batch["obs"] batch.obs = [1] assert batch.obs == [1] batch.cat_(batch) assert np.allclose(batch.obs, [1, 1]) assert batch.np.shape == (6, 4) assert np.allclose(batch[0].obs, batch[1].obs) batch.obs = np.arange(5) for i, b in enumerate(batch.split(1, shuffle=False)): if i != 5: assert b.obs == batch[i].obs else: with pytest.raises(AttributeError): batch[i].obs with pytest.raises(AttributeError): b.obs print(batch) batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])} batch_item = Batch({'a': [batch_dict]})[0] assert isinstance(batch_item.a.b, np.ndarray) assert batch_item.a.b == batch_dict['b'] assert isinstance(batch_item.a.c, float) assert batch_item.a.c == batch_dict['c'] assert isinstance(batch_item.a.d, torch.Tensor) assert batch_item.a.d == batch_dict['d'] batch2 = Batch(a=[{ 'b': np.float64(1.0), 'c': np.zeros(1), 'd': Batch(e=np.array(3.0)) }]) assert len(batch2) == 1 assert Batch().shape == [] assert Batch(a=1).shape == [] assert batch2.shape[0] == 1 with pytest.raises(IndexError): batch2[-2] with pytest.raises(IndexError): batch2[1] assert batch2[0].shape == [] with pytest.raises(IndexError): batch2[0][0] with pytest.raises(TypeError): len(batch2[0]) assert isinstance(batch2[0].a.c, np.ndarray) assert isinstance(batch2[0].a.b, np.float64) assert isinstance(batch2[0].a.d.e, np.float64) batch2_from_list = Batch(list(batch2)) batch2_from_comp = Batch([e for e in batch2]) assert batch2_from_list.a.b == batch2.a.b assert batch2_from_list.a.c == batch2.a.c assert batch2_from_list.a.d.e == batch2.a.d.e assert batch2_from_comp.a.b == batch2.a.b assert batch2_from_comp.a.c == batch2.a.c assert batch2_from_comp.a.d.e == batch2.a.d.e for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]: assert batch_slice.a.b == batch2.a.b assert batch_slice.a.c == batch2.a.c assert batch_slice.a.d.e == batch2.a.d.e batch2_sum = (batch2 + 1.0) * 2 assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 batch3 = Batch(a={ 'c': np.zeros(1), 'd': Batch(e=np.array([0.0]), f=np.array([3.0])) }) batch3.a.d[0] = {'e': 4.0} assert batch3.a.d.e[0] == 4.0 batch3.a.d[0] = Batch(f=5.0) assert batch3.a.d.f[0] == 5.0 with pytest.raises(KeyError): batch3.a.d[0] = Batch(f=5.0, g=0.0) # auto convert batch4 = Batch(a=np.array(['a', 'b'])) assert batch4.a.dtype == np.object # auto convert to np.object batch4.update(a=np.array(['c', 'd'])) assert list(batch4.a) == ['c', 'd'] assert batch4.a.dtype == np.object # auto convert to np.object batch5 = Batch(a=np.array([{'index': 0}])) assert isinstance(batch5.a, Batch) assert np.allclose(batch5.a.index, [0]) batch5.b = np.array([{'index': 1}]) assert isinstance(batch5.b, Batch) assert np.allclose(batch5.b.index, [1]) # None is a valid object and can be stored in Batch a = Batch.stack([Batch(a=None), Batch(b=None)]) assert a.a[0] is None and a.a[1] is None assert a.b[0] is None and a.b[1] is None