Ejemplo n.º 1
0
def test_base_storage():
    """Check attribute access, lookup, construction and copying of
    ``BaseStorage``.

    Covers, in order: setting attributes and ``len``; ``get`` with a
    default; filtered ``keys``/``values``/``items``; attribute deletion;
    construction from a dict and from keyword arguments; and the difference
    between ``copy.copy`` and ``copy.deepcopy`` with respect to the
    underlying tensor memory.
    """
    storage = BaseStorage()
    storage.x = torch.zeros(1)
    storage.y = torch.ones(1)
    assert len(storage) == 2
    assert storage.x is not None
    assert storage.y is not None

    # `get` yields the stored tensor for a present key and the supplied
    # default for the absent key 'z'.
    assert torch.allclose(storage.get('x', None), storage.x)
    assert torch.allclose(storage.get('y', None), storage.y)
    assert storage.get('z', 2) == 2
    assert storage.get('z', None) is None

    # 'z' was never set, so asking for ('x', 'y', 'z') yields two entries.
    assert len(list(storage.keys('x', 'y', 'z'))) == 2
    assert len(list(storage.values('x', 'y', 'z'))) == 2
    assert len(list(storage.items('x', 'y', 'z'))) == 2

    del storage.y
    assert len(storage) == 1
    assert storage.x is not None

    # Construction from a positional mapping.
    storage = BaseStorage({'x': torch.zeros(1)})
    assert len(storage) == 1
    assert storage.x is not None

    # Construction from keyword arguments.
    storage = BaseStorage(x=torch.zeros(1))
    assert len(storage) == 1
    assert storage.x is not None

    # A shallow copy is a distinct, equal object whose tensor points at the
    # same memory as the original.
    storage = BaseStorage(x=torch.zeros(1))
    copied_storage = copy.copy(storage)
    assert storage == copied_storage
    assert storage is not copied_storage
    assert storage.x.data_ptr() == copied_storage.x.data_ptr()
    assert int(storage.x) == 0
    assert int(copied_storage.x) == 0

    # A deep copy is also distinct and equal, but its tensor lives at a
    # different address.
    deepcopied_storage = copy.deepcopy(storage)
    assert storage == deepcopied_storage
    assert storage is not deepcopied_storage
    assert storage.x.data_ptr() != deepcopied_storage.x.data_ptr()
    assert int(storage.x) == 0
    assert int(deepcopied_storage.x) == 0
Ejemplo n.º 2
0
def test_base_storage():
    """Check attribute access, construction and copying of ``BaseStorage``.

    This variant builds every storage with an explicit ``_parent=None``.
    Covers, in order: setting attributes and ``len``; filtered
    ``keys``/``values``/``items``; attribute deletion; construction via
    ``_mapping`` and via keyword arguments; and the difference between
    ``copy.copy`` and ``copy.deepcopy`` with respect to the underlying
    tensor memory.
    """
    storage = BaseStorage(_parent=None)
    storage.x = torch.zeros(1)
    storage.y = torch.ones(1)
    assert len(storage) == 2
    assert storage.x is not None
    assert storage.y is not None

    # 'z' was never set, so asking for ('x', 'y', 'z') yields two entries.
    assert len(list(storage.keys('x', 'y', 'z'))) == 2
    assert len(list(storage.values('x', 'y', 'z'))) == 2
    assert len(list(storage.items('x', 'y', 'z'))) == 2

    del storage.y
    assert len(storage) == 1
    assert storage.x is not None

    # Construction from an explicit `_mapping` dict.
    storage = BaseStorage(_parent=None, _mapping={'x': torch.zeros(1)})
    assert len(storage) == 1
    assert storage.x is not None

    # Construction from extra keyword arguments.
    storage = BaseStorage(_parent=None, x=torch.zeros(1))
    assert len(storage) == 1
    assert storage.x is not None

    # A shallow copy is a distinct, equal object whose tensor points at the
    # same memory as the original.
    storage = BaseStorage(_parent=None, x=torch.zeros(1))
    copied_storage = copy.copy(storage)
    assert storage == copied_storage
    assert storage is not copied_storage
    assert storage.x.data_ptr() == copied_storage.x.data_ptr()
    assert int(storage.x) == 0
    assert int(copied_storage.x) == 0

    # A deep copy is also distinct and equal, but its tensor lives at a
    # different address.
    deepcopied_storage = copy.deepcopy(storage)
    assert storage == deepcopied_storage
    assert storage is not deepcopied_storage
    assert storage.x.data_ptr() != deepcopied_storage.x.data_ptr()
    assert int(storage.x) == 0
    assert int(deepcopied_storage.x) == 0