def test_rotate_90_trnsforms_have_same_bahaviour(k):
    """Check RandomRotate pinned to ``k * 90`` degrees agrees with RandomRotate90(k).

    Both transforms are created with ``p=1`` and sampled once; the
    ``'transform_matrix'`` entries of their ``state_dict`` must be exactly
    equal element-wise.
    """
    angle = k * 90

    generic_rotate = slt.RandomRotate(rotation_range=(angle, angle), p=1)
    generic_rotate.sample_transform()

    quarter_rotate = slt.RandomRotate90(k=k, p=1)
    quarter_rotate.sample_transform()

    generic_matrix = generic_rotate.state_dict['transform_matrix']
    quarter_matrix = quarter_rotate.state_dict['transform_matrix']
    assert np.array_equal(generic_matrix, quarter_matrix)
def test_rotate_90_img_mask_nondestructive(k, img_3x3, mask_3x3):
    """RandomRotate90 must rotate an image and its mask like ``np.rot90(x, -k)``.

    ``np.rot90`` rotates counterclockwise for a positive count, so comparing
    against ``np.rot90(x, -k)`` asserts that the transform rotates clockwise
    by ``k`` quarter turns for positive ``k``. The image result is expected
    to carry a trailing single-channel axis; the mask result is 2-D.
    """
    # Setting up the data
    img, mask = img_3x3, mask_3x3
    dc = sld.DataContainer((img, mask), 'IM')

    # Defining the 90 degrees transform. Per the expected values below, it
    # rotates clockwise for positive k (np.rot90 with -k).
    stream = slt.RandomRotate90(k=k, p=1)
    dc_res = stream(dc)

    img_res, _, _ = dc_res[0]
    mask_res, _, _ = dc_res[1]

    # Take the expected shape from the rotated array rather than the input
    # (H, W): an odd k swaps height and width for non-square inputs, and
    # reshaping to the original (H, W, 1) would scramble the pixels.
    rotated_img = np.rot90(img, -k)
    expected_img_res = rotated_img.reshape(rotated_img.shape[:2] + (1,))
    expected_mask_res = np.rot90(mask, -k)

    assert np.array_equal(expected_img_res, img_res)
    assert np.array_equal(expected_mask_res, mask_res)
def test_rotate_nondestructive_does_not_accept_non_int_k(k):
    """Constructing RandomRotate90 with a non-integer ``k`` must raise TypeError."""
    pytest.raises(TypeError, slt.RandomRotate90, k=k)