Example #1
0
def test_batch_over_batch():
    """Check nested ``Batch`` construction, ``cat_`` and advanced slicing."""
    inner = Batch(a=[3, 4, 5], b=[4, 5, 6])
    outer = Batch({'c': [6, 7, 8], 'b': inner})
    # mutate a leaf of the nested batch in place
    outer.b.b[-1] = 0
    print(outer)
    for key, value in outer.items():
        assert np.all(outer[key] == value)
    assert outer[-1].b.b == 0

    # concatenating a batch whose 'b' entry is ``inner`` again
    outer.cat_(Batch(c=[6, 7, 8], b=inner))
    for got, want in (
        (outer.c, [6, 7, 8, 6, 7, 8]),
        (outer.b.a, [3, 4, 5, 3, 4, 5]),
        (outer.b.b, [4, 5, 0, 4, 5, 0]),
    ):
        assert np.allclose(got, want)

    # nested plain dicts are accepted as values as well
    raw = {'a': [3, 4, 5], 'b': [4, 5, 6]}
    from_dict = Batch(c=[6, 7, 8], b=raw)
    from_dict.cat_(Batch(c=[6, 7, 8], b=raw))
    for got, want in (
        (from_dict.c, [6, 7, 8, 6, 7, 8]),
        (from_dict.b.a, [3, 4, 5, 3, 4, 5]),
        (from_dict.b.b, [4, 5, 6, 4, 5, 6]),
    ):
        assert np.allclose(got, want)

    # a tuple of dicts is stacked, adding a leading dimension
    stacked = Batch(({'a': {'b': np.array([1.0])}},))
    leaf = stacked.a.b
    assert leaf.ndim == 2
    assert leaf[0, 0] == 1.0

    # advanced slicing
    sliced = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
    assert sliced.shape == [1, 2]
    # equivalent to sliced[2], sliced[:, 3] and sliced[:, :, -1]
    out_of_range = (2, (slice(None), 3), (slice(None), slice(None), -1))
    for index in out_of_range:
        with pytest.raises(IndexError):
            sliced[index]
    sliced[:, -1] += 1
    assert np.allclose(sliced.a, [1, 3])
    assert np.allclose(sliced.b.c.squeeze(), [[0, 1]] * 3)
Example #2
0
def _batch_set_item(
    source: Batch, indices: np.ndarray, target: Batch, size: int
) -> None:
    """Write the values of ``target`` into ``source`` at ``indices``.

    Walks ``target`` key by key, recursing into nested ``Batch`` values, and
    mutates ``source`` in place. Keys present only in ``source`` are left
    untouched.

    Args:
        source: batch being written into (mutated in place).
        indices: positions in ``source`` that receive ``target``'s values.
        target: batch supplying the values.
        size: forwarded to ``_create_value`` when ``source`` has no storage
            for a key that ``target`` provides (presumably the length of the
            storage to allocate -- confirm against ``_create_value``).
    """
    # for any key chain k, there are four cases
    # 1. source[k] is non-reserved, but target[k] does not exist or is reserved
    # 2. source[k] does not exist or is reserved, but target[k] is non-reserved
    # 3. both source[k] and target[k] are non-reserved
    # 4. both source[k] and target[k] do not exist or are reserved, do nothing.
    # A special case in case 4, if target[k] is reserved but source[k] does
    # not exist, make source[k] reserved, too.
    for k, vt in target.items():
        if not isinstance(vt, Batch) or not vt.is_empty():
            # target[k] is non-reserved
            vs = source.get(k, Batch())
            if isinstance(vs, Batch):
                if vs.is_empty():
                    # case 2, use __dict__ to avoid many type checks;
                    # allocate storage from one element of vt, then fall
                    # through to the indexed assignment below
                    source.__dict__[k] = _create_value(vt[0], size)
                else:
                    # case 3 with a nested Batch: the recursive call already
                    # writes every sub-key, so skip the flat assignment below
                    # (it would redo the write via source[k].__setitem__ and
                    # could touch sub-keys the recursion leaves alone)
                    assert isinstance(vt, Batch)
                    _batch_set_item(source.__dict__[k], indices, vt, size)
                    continue
        else:
            # target[k] is reserved
            # case 1 or special case of case 4
            if k not in source.__dict__:
                source.__dict__[k] = Batch()
            continue
        # case 2 (freshly allocated) or case 3 with a non-Batch source value
        source.__dict__[k][indices] = vt
Example #3
0
def test_batch_over_batch():
    """Check nested ``Batch`` construction and ``cat_`` concatenation.

    Values stored in a ``Batch`` are compared with ``np.all`` /
    ``np.allclose`` rather than a bare ``==``: asserting on the element-wise
    result of comparing a multi-element numpy array raises
    "truth value of an array ... is ambiguous".
    """
    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
    batch2 = Batch({'c': [6, 7, 8], 'b': batch})
    batch2.b.b[-1] = 0
    print(batch2)
    for k, v in batch2.items():
        # reduce the element-wise comparison to a single bool
        assert np.all(batch2[k] == v)
    assert batch2[-1].b.b == 0
    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
    d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
    batch3 = Batch(c=[6, 7, 8], b=d)
    batch3.cat_(Batch(c=[6, 7, 8], b=d))
    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
    batch4 = Batch(({'a': {'b': np.array([1.0])}},))
    assert batch4.a.b.ndim == 2
    assert batch4.a.b[0, 0] == 1.0