def test_rotate_90_transforms_have_same_behaviour(k, img_6x6_rgb):
    """A fixed ``k * 90`` degree Rotate must sample the same matrix as Rotate90(k).

    Both transforms sample their parameters from the same single-image
    container. Their ``state_dict["transform_matrix"]`` entries are then
    compared for exact equality.
    """
    container = slc.DataContainer(img_6x6_rgb, "I")

    # Degenerate angle range: the sampled angle is always exactly k * 90.
    generic_rotation = slt.Rotate(angle_range=(k * 90, k * 90), p=1)
    generic_rotation.sample_transform(container)

    quarter_turns = slt.Rotate90(k=k, p=1)
    quarter_turns.sample_transform(container)

    matrix_generic = generic_rotation.state_dict["transform_matrix"]
    matrix_quarter = quarter_turns.state_dict["transform_matrix"]
    assert np.array_equal(matrix_generic, matrix_quarter)
def create_train_transforms(size):
    """Build the training-time augmentation stream.

    The stream applies these steps in order:

    1. Probabilistic JPEG compression and noise.
    2. Brightness / contrast jitter, flips and 90-degree rotations.
    3. A selective sub-stream built with ``n=3`` over gamma correction,
       noise, salt-and-pepper and blur.
    4. A small random rotation.
    5. A final resize.

    Args:
        size: Value used for both dimensions of the final ``Resize``,
            i.e. ``Resize((size, size))``.

    Returns:
        A ``solt.Stream`` wrapping the transforms above.
    """
    augmentations = [
        slt.JPEGCompression(p=0.5, quality_range=(60, 100)),
        slt.Noise(p=0.25),
        slt.Brightness(),
        slt.Contrast(),
        slt.Flip(),
        slt.Rotate90(),
    ]
    augmentations.append(
        solt.SelectiveStream(
            [
                slt.GammaCorrection(gamma_range=0.5, p=1),
                slt.Noise(gain_range=0.1, p=1),
                slt.SaltAndPepper(),
                slt.Blur(),
            ],
            n=3,
        )
    )
    augmentations.append(slt.Rotate(angle_range=(-10, 10), p=0.5))
    augmentations.append(slt.Resize((size, size)))
    return solt.Stream(augmentations)
def test_rotate_90_img_mask_nondestructive(k, img, mask):
    """Rotate90(k) must rotate image and mask losslessly, like ``np.rot90(x, -k)``.

    The expected outputs use ``np.rot90`` with ``-k``. Under NumPy's
    convention, that is a clockwise rotation of the displayed image for
    positive ``k``.

    The expected image is reshaped to ``(H, W, 1)``, where ``(H, W)`` is the
    mask's shape. This presumably assumes square single-channel fixtures
    when ``k`` is odd — TODO confirm against the fixture definitions.
    """
    height, width = mask.shape
    container = slc.DataContainer((img, mask), "IM")

    rotated = slt.Rotate90(k=k, p=1)(container)
    rotated_img, _, _ = rotated[0]
    rotated_mask, _, _ = rotated[1]

    reference_img = np.rot90(img, -k).reshape((height, width, 1))
    reference_mask = np.rot90(mask, -k)

    assert np.array_equal(reference_img, rotated_img)
    assert np.array_equal(reference_mask, rotated_mask)
def test_complex_transform_serialization():
    """A deeply nested stream must serialize to YAML stably.

    The stream mixes:

    * plain transforms;
    * nested ``SelectiveStream``s, including empty ``Stream``s as no-op
      choices;
    * a ``Projection`` carrying its own affine sub-stream.

    The test then checks two properties:

    1. Determinism: deserializing the same YAML string twice yields
       identical re-serializations.
    2. Round-trip stability: the re-serialized YAML is a fixed point, so
       loading and dumping it again leaves it unchanged.
    """
    stream = slc.Stream(
        [
            slt.Flip(axis=1, p=0.5),
            slc.SelectiveStream(
                [
                    slt.Rotate(angle_range=(-45, -45), p=1, padding="r"),
                    slt.Rotate90(1, p=1),
                    slt.Rotate(angle_range=(45, 45), p=1, padding="r"),
                ]
            ),
            slt.Crop((350, 350)),
            slc.SelectiveStream(
                [
                    slt.GammaCorrection(gamma_range=0.5, p=1),
                    slt.Noise(gain_range=0.1, p=1),
                    slt.Blur(),
                ],
                n=3,
            ),
            slt.Projection(
                affine_transforms=slc.Stream(
                    [
                        slt.Rotate(angle_range=(-45, 45), p=1),
                        slt.Scale(
                            range_x=(0.8, 1.5),
                            range_y=(0.8, 1.5),
                            p=1,
                            same=False,
                        ),
                    ]
                ),
                v_range=(1e-4, 1e-3),
                p=1,
            ),
            slc.SelectiveStream(
                [
                    slt.CutOut(40, p=1),
                    slt.CutOut(30, p=1),
                    slt.CutOut(20, p=1),
                    slt.CutOut(40, p=1),
                    slc.Stream(),
                    slc.Stream(),
                    slc.Stream(),
                ],
                n=3,
            ),
        ]
    )

    serialized = stream.to_yaml()
    restored_yaml = slu.from_yaml(serialized).to_yaml()

    # Determinism: two independent loads of the same YAML agree.
    assert restored_yaml == slu.from_yaml(serialized).to_yaml()
    # Round-trip stability: compare one round trip against a second, so the
    # check is not a comparison of an expression with itself.
    assert slu.from_yaml(restored_yaml).to_yaml() == restored_yaml
def test_rotate_nondestructive_does_not_accept_non_int_k(k):
    """Constructing Rotate90 with a non-integer ``k`` must raise TypeError."""
    # The failure is expected at construction time; no data is transformed.
    with pytest.raises(TypeError):
        slt.Rotate90(k=k)