def test_template_transform_incorrect_size(template):
    """TemplateTransform must reject a template whose spatial size differs from the image.

    A random 512x512 RGB uint8 image is paired with the (parametrized) ``template``;
    building the transform and applying it must raise ``ValueError`` with a message
    naming both ``(height, width)`` pairs.
    """
    picture = np.random.randint(0, 256, (512, 512, 3), np.uint8)
    expected = (
        f"Image and template must be the same size, got "
        f"{picture.shape[:2]} and {template.shape[:2]}"
    )

    with pytest.raises(ValueError) as raised:
        A.TemplateTransform(template, always_apply=True)(image=picture)

    assert str(raised.value) == expected
def test_template_transform_incorrect_channels(img_channels, template_channels):
    """TemplateTransform must reject a template with an incompatible channel count.

    Both arrays are 512x512 uint8 with the parametrized channel counts; applying the
    transform must raise ``ValueError`` whose message reports the image's channel
    count and the template's last-axis size.
    """
    picture = np.random.randint(0, 256, [512, 512, img_channels], np.uint8)
    overlay = np.random.randint(0, 256, [512, 512, template_channels], np.uint8)
    expected = (
        "Template must be a single channel or has the same number of channels "
        f"as input image ({img_channels}), got {overlay.shape[-1]}"
    )

    with pytest.raises(ValueError) as raised:
        A.TemplateTransform(overlay, always_apply=True)(image=picture)

    assert str(raised.value) == expected
def test_template_transform(
    image, img_weight, template_weight, template_transform, image_size, template_size
):
    """TemplateTransform keeps the image shape and yields a template matching the image.

    A random uint8 image of ``image_size`` and template of ``template_size`` are
    generated. The augmented output must have the input's shape, and the template
    returned by ``get_params_dependent_on_targets`` must match the input's shape
    and dtype.
    """
    # NOTE(review): `image` is not used in the body; presumably a fixture/param
    # kept for the test's parametrization — confirm before removing.
    def _random_uint8(shape):
        return np.random.randint(0, 256, shape, np.uint8)

    source = _random_uint8(image_size)
    overlay = _random_uint8(template_size)

    augmenter = A.TemplateTransform(overlay, img_weight, template_weight, template_transform)
    output = augmenter(image=source)["image"]
    assert result_shape_ok(output, source) if False else output.shape == source.shape

    prepared = augmenter.get_params_dependent_on_targets({"image": source})["template"]
    assert prepared.shape == source.shape
    assert prepared.dtype == source.dtype
def test_template_transform_serialization(image, template, seed, p):
    """A pipeline containing TemplateTransform survives a to_dict/from_dict round trip.

    The TemplateTransform instance is named ``"template"`` and is supplied back to
    ``A.from_dict`` through ``lambda_transforms`` under that key. The original and
    restored pipelines are each run after reseeding with ``seed`` and must produce
    identical images.
    """
    templated = A.TemplateTransform(name="template", templates=template, p=p)
    original_pipeline = A.Compose([A.Flip(), templated, A.Blur()])

    restored_pipeline = A.from_dict(
        A.to_dict(original_pipeline),
        lambda_transforms={"template": templated},
    )

    outputs = []
    for pipeline in (original_pipeline, restored_pipeline):
        # Reseed before each run so both pipelines see the same random draws.
        set_seed(seed)
        outputs.append(pipeline(image=image)["image"])

    assert np.array_equal(outputs[0], outputs[1])