def test_random_shape(self, input_param, input_shape, expected_shape):
    """Check that RandSpatialCrop yields ``expected_shape`` for every array type.

    For each array constructor in ``TEST_NDARRAYS_ALL``, a fresh cropper is
    built from ``input_param`` and seeded with 123. It is then applied to a
    random integer array of ``input_shape`` whose values are 0 or 1
    (``np.random.randint(0, 2, ...)``). Only the output shape is asserted.
    """
    for array_ctor in TEST_NDARRAYS_ALL:
        # The keyword stays ``im_type`` so subTest reports are unchanged.
        with self.subTest(im_type=array_ctor):
            crop = RandSpatialCrop(**input_param)
            crop.set_random_state(seed=123)
            # The values come from NumPy's global RNG. They do not affect the
            # assertion, because only the shape is checked.
            binary_image = array_ctor(np.random.randint(0, 2, input_shape))
            cropped = crop(binary_image)
            self.assertTupleEqual(cropped.shape, expected_shape)
def test_random_shape(self, input_param, input_data, expected_shape):
    """Check that a seeded RandSpatialCrop maps ``input_data`` to ``expected_shape``.

    The cropper is built from ``input_param`` and seeded with 123 before it is
    applied once to ``input_data``. Only the shape of the result is compared.
    """
    transform = RandSpatialCrop(**input_param)
    transform.set_random_state(seed=123)
    self.assertTupleEqual(transform(input_data).shape, expected_shape)