def test_batch_over_batch():
    """Nested Batch: construction, item access, in-place ``cat_`` and
    multi-axis (advanced) indexing."""
    # NOTE(review): a later function in this file reuses this exact name; the
    # later ``def`` rebinds the module-level name, so pytest may never collect
    # this body -- consider renaming one of them.
    inner = Batch(a=[3, 4, 5], b=[4, 5, 6])
    outer = Batch({'c': [6, 7, 8], 'b': inner})
    outer.b.b[-1] = 0
    print(outer)
    for key, value in outer.items():
        assert np.all(outer[key] == value)
    assert outer[-1].b.b == 0

    # cat_ with a nested Batch appends every leaf along the first axis
    outer.cat_(Batch(c=[6, 7, 8], b=inner))
    for got, want in (
        (outer.c, [6, 7, 8, 6, 7, 8]),
        (outer.b.a, [3, 4, 5, 3, 4, 5]),
        (outer.b.b, [4, 5, 0, 4, 5, 0]),
    ):
        assert np.allclose(got, want)

    # same check, but the nested value is supplied as a plain dict
    raw = {'a': [3, 4, 5], 'b': [4, 5, 6]}
    from_dict = Batch(c=[6, 7, 8], b=raw)
    from_dict.cat_(Batch(c=[6, 7, 8], b=raw))
    for got, want in (
        (from_dict.c, [6, 7, 8, 6, 7, 8]),
        (from_dict.b.a, [3, 4, 5, 3, 4, 5]),
        (from_dict.b.b, [4, 5, 6, 4, 5, 6]),
    ):
        assert np.allclose(got, want)

    # a one-element tuple of dicts gives the 1-d leaf an extra leading axis
    stacked = Batch(({'a': {'b': np.array([1.0])}},))
    assert stacked.a.b.ndim == 2
    assert stacked.a.b[0, 0] == 1.0

    # advanced slicing
    sliced = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
    assert sliced.shape == [1, 2]
    # equivalent to sliced[2], sliced[:, 3], sliced[:, :, -1]
    for bad_index in (2, (slice(None), 3), (slice(None), slice(None), -1)):
        with pytest.raises(IndexError):
            sliced[bad_index]
    sliced[:, -1] += 1
    assert np.allclose(sliced.a, [1, 3])
    assert np.allclose(sliced.b.c.squeeze(), [[0, 1]] * 3)
def _batch_set_item(
    source: Batch, indices: np.ndarray, target: Batch, size: int
) -> None:
    """Write every entry of ``target`` into ``source`` at ``indices``.

    An entry is "reserved" when it is an empty ``Batch``. Only keys present
    in ``target`` are visited. For each such key ``k``:

    * target[k] reserved: if ``source`` has no ``k`` yet, an empty ``Batch``
      is stored under ``k``; nothing is written.
    * source[k] missing or reserved: it is replaced by
      ``_create_value(target[k][0], size)`` (helper defined elsewhere;
      presumably allocates storage for ``size`` items -- confirm), then
      ``target[k]`` is written at ``indices``.
    * source[k] a non-empty ``Batch``: ``target[k]`` must be a ``Batch`` too
      (asserted); the function recurses into it, then writes at ``indices``.
    * otherwise ``target[k]`` is written into ``source[k][indices]``.

    Args:
        source: batch modified in place.
        indices: positions of ``source`` to write to.
        target: batch providing the values.
        size: forwarded to ``_create_value`` and to the recursive call.
    """
    for key, new_val in target.items():
        if isinstance(new_val, Batch) and new_val.is_empty():
            # target[key] is reserved: case 1, or the special case of case 4
            if key not in source.__dict__:
                source.__dict__[key] = Batch()
            continue

        old_val = source.get(key, Batch())
        if isinstance(old_val, Batch):
            if old_val.is_empty():
                # case 2, use __dict__ to avoid many type checks
                source.__dict__[key] = _create_value(new_val[0], size)
            else:
                # case 3 with nested batches: recurse so nested keys exist
                assert isinstance(new_val, Batch)
                _batch_set_item(source.__dict__[key], indices, new_val, size)
        # NOTE(review): after the recursive call above, nested leaves are
        # written a second time here; kept to preserve existing behaviour.
        source.__dict__[key][indices] = new_val
def test_batch_over_batch():
    """Exercise a Batch nested inside another Batch and in-place ``cat_``.

    Array comparisons are reduced with ``np.all`` / ``np.allclose``. The
    leaves behave as NumPy arrays (``batch4.a.b.ndim`` below relies on that),
    and a bare ``assert arr == [...]`` evaluates the truth value of a
    multi-element array, which raises ``ValueError`` instead of checking
    equality.
    """
    # NOTE(review): a function named ``test_batch_over_batch`` is also defined
    # earlier in this file. This later ``def`` rebinds the module-level name,
    # so pytest collects only one of them -- consider renaming one.
    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
    batch2 = Batch({'c': [6, 7, 8], 'b': batch})
    batch2.b.b[-1] = 0
    print(batch2)
    for k, v in batch2.items():
        # ``==`` may be elementwise; reduce it to a single bool
        assert np.all(batch2[k] == v)
    assert batch2[-1].b.b == 0
    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
    d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
    batch3 = Batch(c=[6, 7, 8], b=d)
    batch3.cat_(Batch(c=[6, 7, 8], b=d))
    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
    batch4 = Batch(({'a': {'b': np.array([1.0])}},))
    assert batch4.a.b.ndim == 2
    assert batch4.a.b[0, 0] == 1.0