示例#1
0
def test_rotate_identity():
    """An identity rotation with zero translation must leave the grid untouched."""
    grid = Grid(
        np.arange(27).reshape((1, 3, 3, 3)),
        np.array([0, 0, 0]),
        np.array([1, 1, 1]),
    )
    snapshot = grid.grid.copy()

    identity, no_shift = np.eye(3), np.array([0, 0, 0])
    grid.move(identity, no_shift)

    np.testing.assert_array_equal(snapshot, grid.grid)
示例#2
0
def test_save(tmpdir):
    """Round-trip a named two-channel grid through ``save`` and ``grid_list``.

    The grid written with ``Grid.save`` and re-read via the ``grid_list``
    constructor argument must match the original values (within 1e-4) and
    carry exactly the same channel names.
    """
    g = Grid(np.ones((2, 3, 3, 3)),
             origin=np.ones(3),
             delta=np.ones(3),
             names=['test1', 'test2'])
    path = tmpdir / 'grid.list'
    g.save(path)
    new = Grid(grid_list=path)
    # Compare absolute differences: the previous one-sided `a - b < tol`
    # check accepted arbitrarily large negative deviations. assert_allclose
    # also fails on a shape mismatch.
    np.testing.assert_allclose(new.grid, g.grid, rtol=0, atol=1e-4)
    # zip() stops at the shorter sequence, so an element-wise comparison
    # would pass if names were lost on reload; compare the full sequences.
    assert list(new.names) == list(g.names)
示例#3
0
def test_translate():
    """Moving a grid by (1, 1, 1) with no rotation shifts every voxel by one.

    Voxels shifted past the upper edge are dropped, and the vacated cells at
    the lower edge are expected to be zero.
    """
    g = Grid(
        np.arange(27).reshape((1, 3, 3, 3)), np.array([0, 0, 0]),
        np.array([1, 1, 1]))
    before = g.grid.copy()
    g.move(np.eye(3), np.array([1, 1, 1]))

    # Build the expected result: the first 2x2x2 corner of the original lands
    # one cell further along each spatial axis; everything else is zero.
    manually_translated = np.zeros_like(g.grid)
    manually_translated[0, 1:, 1:, 1:] = before[0, :2, :2, :2]

    # compare translated grids
    np.testing.assert_array_equal(manually_translated, g.grid)
示例#4
0
def test_create_2():
    """Constructing a Grid from a 3-D ``(1, 3, 3)`` array raises RuntimeError.

    The other tests build grids from 4-D arrays; this one checks that a
    lower-rank input is rejected.
    """
    with pytest.raises(RuntimeError):
        Grid(np.ones((
            1,
            3,
            3,
        )), origin=np.ones(3), delta=np.ones(3))
示例#5
0
def test_rotate_180_around_Z():
    """Rotating the grid by pi about Z must permute voxels in each k-slice.

    The expected grid is built by hand from ``INDEX_MAP_180`` and compared
    exactly with the result of ``Grid.move``.
    """
    rot_mat = _get_rotation_matrix_around_Z(np.pi)
    g = Grid(
        np.arange(27).reshape((1, 3, 3, 3)), np.array([0, 0, 0]),
        np.array([1, 1, 1]))
    before = g.grid.copy()
    # Rotate only; zero translation.
    g.move(rot_mat, np.array([0, 0, 0]))

    # flip entries to match the rotated grid
    manually_rotated = np.zeros_like(g.grid)
    # The k index is copied through unchanged, so only the (i, j) plane is
    # permuted.
    for k in range(g.grid.shape[3]):
        for _i, _j in product(range(3), range(3)):
            # NOTE(review): the map is indexed [_j][_i] and unpacked as (j, i);
            # presumably INDEX_MAP_180 stores destination positions in (j, i)
            # order for a 180-degree turn — confirm against its definition
            # before reordering anything here.
            j, i = INDEX_MAP_180[_j][_i]
            manually_rotated[0, i, j, k] = before[0, _i, _j, k]

    # compare rotated grids
    np.testing.assert_array_equal(manually_rotated, g.grid)
示例#6
0
def test_copy():
    """``Grid.copy`` returns a distinct object carrying the same grid values."""
    g = Grid(np.ones((2, 3, 3, 3)),
             origin=np.ones(3),
             delta=np.ones(3),
             names=None)
    g_copy = g.copy()
    # The result of copy() was previously never inspected, so the test only
    # proved that the call does not raise. Pin the minimal contract of a copy.
    assert g_copy is not g
    np.testing.assert_array_equal(g_copy.grid, g.grid)
    # NOTE(review): whether copy() is deep (independent arrays) isn't visible
    # here; if it is, also assert `not np.shares_memory(g_copy.grid, g.grid)`.
示例#7
0
def test_create_4():
    """Constructing a Grid from torch tensors raises RuntimeError."""
    with pytest.raises(RuntimeError):
        Grid(torch.ones((1, 3, 3, 3)),
             origin=torch.ones(3),
             delta=torch.ones(3))
示例#8
0
def test_create_1():
    """A Grid built from a (1, 3, 3, 3) numpy array stores a 4-D grid."""
    grid_obj = Grid(np.ones((1, 3, 3, 3)),
                    origin=np.ones(3),
                    delta=np.ones(3))
    assert grid_obj.grid.ndim == 4