Code example #1
# Imports assumed for recent lightning-flash releases; these module paths have moved between versions.
import pytest
import torch
from unittest.mock import Mock

from flash.core.data.io.input import DataKeys
from flash.core.data.transforms import ApplyToKeys


class TestApplyToKeys:
    @pytest.mark.parametrize(
        "sample, keys, expected",
        [
            ({DataKeys.INPUT: "test"}, DataKeys.INPUT, "test"),
            (
                {DataKeys.INPUT: "test_a", DataKeys.TARGET: "test_b"},
                [DataKeys.INPUT, DataKeys.TARGET],
                ["test_a", "test_b"],
            ),
            ({"input": "test"}, "input", "test"),
            ({"input": "test_a", "target": "test_b"}, ["input", "target"], ["test_a", "test_b"]),
            ({"input": "test_a", "target": "test_b", "extra": "..."}, ["input", "target"], ["test_a", "test_b"]),
            ({"input": "test_a", "target": "test_b"}, ["input", "target", "extra"], ["test_a", "test_b"]),
            ({"target": "..."}, "input", None),
        ],
    )
    def test_forward(self, sample, keys, expected):
        transform = Mock(return_value=["out"] * len(keys))
        ApplyToKeys(keys, transform)(sample)
        if expected is not None:
            transform.assert_called_once_with(expected)
        else:
            transform.assert_not_called()

    @pytest.mark.parametrize(
        "transform, expected",
        [
            (
                ApplyToKeys(DataKeys.INPUT, torch.nn.ReLU()),
                "ApplyToKeys(keys=<DataKeys.INPUT: 'input'>, transform=ReLU())",
            ),
            (
                ApplyToKeys([DataKeys.INPUT, DataKeys.TARGET], torch.nn.ReLU()),
                "ApplyToKeys(keys=[<DataKeys.INPUT: 'input'>, " "<DataKeys.TARGET: 'target'>], transform=ReLU())",
            ),
            (ApplyToKeys("input", torch.nn.ReLU()), "ApplyToKeys(keys='input', transform=ReLU())"),
            (
                ApplyToKeys(["input", "target"], torch.nn.ReLU()),
                "ApplyToKeys(keys=['input', 'target'], transform=ReLU())",
            ),
        ],
    )
    def test_repr(self, transform, expected):
        assert repr(transform) == expected
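
For orientation, a minimal sketch of the behaviour these tests pin down (import path assumed for recent lightning-flash releases):

import torch
from flash.core.data.transforms import ApplyToKeys

sample = {"input": torch.tensor([-1.0, 2.0]), "target": 1}
out = ApplyToKeys("input", torch.nn.ReLU())(sample)
# The transform only sees sample["input"]; other keys pass through untouched.
# With a list of keys, the wrapped transform receives all matching values in a single call.
assert out["target"] == 1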
Code example #2
    def per_batch_transform_on_device(self) -> Callable:
        return ApplyToKeys(
            "video",
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )
Code example #3
    def per_sample_transform(self) -> Callable:
        per_sample_transform = [CenterCrop(self.image_size)]

        # `normalize` is defined elsewhere in the enclosing module (this snippet is an excerpt).
        return ApplyToKeys(
            "video",
            Compose([UniformTemporalSubsample(self.temporal_sub_sample), normalize] + per_sample_transform),
        )
Code example #4
File: data.py, Project: stjordanis/lightning-flash
    def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
        default_transforms = default_transforms_fn(*args, **kwargs)
        if not default_transforms:
            return default_transforms

        # `keys` is captured from the enclosing decorator's scope (the decorator itself is not shown).
        return {
            hook: ApplyToKeys(keys, transform)
            for hook, transform in default_transforms.items()
        }
Code example #5
def test_from_filepaths_splits(tmpdir):
    tmpdir = Path(tmpdir)

    B, _, H, W = 2, 3, 224, 224
    img_size: Tuple[int, int] = (H, W)

    (tmpdir / "splits").mkdir()
    _rand_image(img_size).save(tmpdir / "s.png")

    num_samples: int = 10
    val_split: float = 0.3

    train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)]

    train_labels: List[int] = list(range(num_samples))

    assert len(train_filepaths) == len(train_labels)

    _to_tensor = {
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
    }

    def run(transform: Any = None):
        dm = ImageClassificationData.from_files(
            train_files=train_filepaths,
            train_targets=train_labels,
            train_transform=transform,
            val_transform=transform,
            batch_size=B,
            num_workers=0,
            val_split=val_split,
            image_size=img_size,
        )
        data = next(iter(dm.train_dataloader()))
        imgs, labels = data["input"], data["target"]
        assert imgs.shape == (B, 3, H, W)
        assert labels.shape == (B,)

    run(_to_tensor)
Code example #6
    def train_per_sample_transform(self) -> Callable:
        per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)]

        # As in code example #3, `normalize` is defined elsewhere in the enclosing module.
        return ApplyToKeys(
            "video",
            Compose([UniformTemporalSubsample(self.temporal_sub_sample), normalize] + per_sample_transform),
        )
Code example #7
def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """During training, we apply the default transforms with an additional ``RandomHorizontalFlip``."""
    return merge_transforms(
        default_transforms(image_size),
        {
            "post_tensor_transform": nn.Sequential(
                ApplyToKeys(
                    [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
                    KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)),
                ),
            ),
        },
    )
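
KorniaParallelTransforms (used here and in code example #16) exists so that one random augmentation is applied with identical parameters to the image and its target mask. A minimal sketch of the underlying idea, assuming a recent Kornia API; this illustrates the mechanism rather than flash's exact implementation:

import torch
import kornia as K

flip = K.augmentation.RandomHorizontalFlip(p=0.5)
image = torch.rand(1, 3, 8, 8)
mask = torch.rand(1, 1, 8, 8)

flipped_image = flip(image)                     # samples random parameters
flipped_mask = flip(mask, params=flip._params)  # replays the same parameters on the mask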
Code example #8
def train_default_transforms(
    spectrogram_size: Tuple[int, int],
    time_mask_param: Optional[int],
    freq_mask_param: Optional[int],
) -> Dict[str, Callable]:
    """During training, we apply the default transforms with optional ``TimeMasking`` and ``FrequencyMasking``."""
    augs = []

    if time_mask_param is not None:
        augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)))

    if freq_mask_param is not None:
        augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)))

    if augs:
        return merge_transforms(default_transforms(spectrogram_size), {"post_tensor_transform": nn.Sequential(*augs)})
    return default_transforms(spectrogram_size)
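
A hypothetical call for illustration (the spectrogram size and mask parameters below are made up, not from the source):

transforms = train_default_transforms(
    spectrogram_size=(128, 128),
    time_mask_param=80,   # pass None to skip TimeMasking
    freq_mask_param=27,   # pass None to skip FrequencyMasking
)
assert "post_tensor_transform" in transforms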
Code example #9
def predict_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """During prediction, we apply the default transforms only to ``DefaultDataKeys.INPUT``."""
    return {
        "post_tensor_transform": nn.Sequential(
            ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.geometry.Resize(image_size, interpolation="nearest"),
            ),
        ),
        "collate": kornia_collate,
    }
Code example #10
def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """The default transforms for image classification: resize the image, convert the image and target to a
    tensor, collate the batch, and apply normalization."""
    if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1":
        # Preferred path: every transform operates directly on tensors.
        return {
            "to_tensor_transform": nn.Sequential(
                ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
                ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
            ),
            "post_tensor_transform": ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.geometry.Resize(image_size),
            ),
            "collate": kornia_collate,
            "per_batch_transform_on_device": ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.augmentation.Normalize(
                    torch.tensor([0.485, 0.456, 0.406]),
                    torch.tensor([0.229, 0.224, 0.225]),
                ),
            ),
        }
    return {
        "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)),
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
        "post_tensor_transform": ApplyToKeys(
            DefaultDataKeys.INPUT,
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ),
        "collate": kornia_collate,
    }
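
A quick sanity check of the mapping returned by the Kornia branch above (hypothetical, grounded only in the snippet itself):

transforms = default_transforms((224, 224))  # assuming _KORNIA_AVAILABLE and FLASH_TESTING != "1"
assert set(transforms) == {
    "to_tensor_transform",
    "post_tensor_transform",
    "collate",
    "per_batch_transform_on_device",
}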
Code example #11
def default_transforms() -> Dict[str, Callable]:
    """The default transforms for object detection: convert the image and targets to a tensor, collate the batch."""
    return {
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys("input", torchvision.transforms.ToTensor()),
            ApplyToKeys(
                "target",
                nn.Sequential(
                    ApplyToKeys("boxes", torch.as_tensor),
                    ApplyToKeys("labels", torch.as_tensor),
                    ApplyToKeys("image_id", torch.as_tensor),
                    ApplyToKeys("area", torch.as_tensor),
                    ApplyToKeys("iscrowd", torch.as_tensor),
                ),
            ),
        ),
        "collate": collate,
    }
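
The nested ApplyToKeys works because a detection sample is a dict whose "target" value is itself a dict. A hypothetical sample showing the structure the transform expects (all values made up):

from PIL import Image

sample = {
    "input": Image.new("RGB", (64, 64)),  # converted by ToTensor()
    "target": {
        "boxes": [[10.0, 20.0, 50.0, 80.0]],
        "labels": [3],
        "image_id": 0,
        "area": [2400.0],
        "iscrowd": [0],
    },
}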
Code example #12
def test_multicrop_input_transform():
    batch_size = 8
    total_crops = 6
    num_crops = [2, 4]
    size_crops = [160, 96]
    crop_scales = [[0.4, 1], [0.05, 0.4]]

    multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](
        total_crops, num_crops, size_crops, crop_scales)

    to_tensor_transform = ApplyToKeys(
        DefaultDataKeys.INPUT,
        multi_crop_transform,
    )
    preprocess = DefaultPreprocess(train_transform={
        "to_tensor_transform": to_tensor_transform,
        "collate": vissl_collate_fn,
    })

    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(),
        preprocess=preprocess,
        batch_size=batch_size,
    )

    train_dataloader = datamodule._train_dataloader()
    batch = next(iter(train_dataloader))

    assert len(batch[DefaultDataKeys.INPUT]) == total_crops
    assert batch[DefaultDataKeys.INPUT][0].shape == (batch_size, 3, size_crops[0], size_crops[0])
    assert batch[DefaultDataKeys.INPUT][-1].shape == (batch_size, 3, size_crops[-1], size_crops[-1])
    assert list(batch[DefaultDataKeys.TARGET].shape) == [batch_size]
Code example #13
    def target_per_sample_transform(self) -> Callable:
        return ApplyToKeys("target_boxes", torch.as_tensor)
Code example #14
    def __resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
        from flash.core.data.data_pipeline import DataPipeline

        transforms_out = {}
        stage = _STAGES_PREFIX[running_stage]

        # iterate over every transform hook name
        for transform_name in InputTransformPlacement:

            transforms = {}
            transform_name = transform_name.value

            # iterate over all prefixes
            for key in ApplyToKeyPrefix:

                # get the resolved hook name based on the current stage
                resolved_name = DataPipeline._resolve_function_hierarchy(
                    transform_name, self, running_stage, InputTransform)
                # check if the hook name is specialized
                is_specialized_name = resolved_name.startswith(stage)

                # get the resolved hook name for apply to key on the current stage
                resolved_apply_to_key_name = DataPipeline._resolve_function_hierarchy(
                    f"{key}_{transform_name}", self, running_stage,
                    InputTransform)
                # check if resolved hook name for apply to key is specialized
                is_specialized_apply_to_key_name = resolved_apply_to_key_name.startswith(
                    stage)

                # check if they are overridden by the user
                resolve_name_overridden = DataPipeline._is_overridden(
                    resolved_name, self, InputTransform)
                resolved_apply_to_key_name_overridden = DataPipeline._is_overridden(
                    resolved_apply_to_key_name, self, InputTransform)

                if resolve_name_overridden and resolved_apply_to_key_name_overridden:
                    # If both hook names are specialized, or neither is, raise an
                    # exception: specialized hook names take priority.
                    if not (is_specialized_name ^ is_specialized_apply_to_key_name):
                        raise MisconfigurationException(
                            f"Only one of {resolved_name} or {resolved_apply_to_key_name} can be overridden."
                        )

                    method_name = resolved_name if is_specialized_name else resolved_apply_to_key_name
                else:
                    method_name = resolved_apply_to_key_name if resolved_apply_to_key_name_overridden else resolved_name

                # get associated transform
                try:
                    fn = getattr(self, method_name)()
                except AttributeError as e:
                    raise AttributeError(
                        str(e) +
                        ". Hint: Call super().__init__(...) after setting all attributes."
                    )

                if not callable(fn):
                    raise MisconfigurationException(
                        f"The hook {method_name} should return a function.")

                # if the default hook is used, it returns the identity; skip it.
                if fn is self._identity:
                    continue

                # wrap apply to key hook into `ApplyToKeys` with the associated key.
                if method_name == resolved_apply_to_key_name:
                    fn = ApplyToKeys(key.value, fn)

                if method_name not in transforms:
                    transforms[method_name] = fn

            # store the transforms.
            if transforms:
                transforms = list(transforms.values())
                transforms_out[transform_name] = Compose(
                    transforms) if len(transforms) > 1 else transforms[0]

        return transforms_out
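
In practice, this resolution lets an InputTransform subclass override either the generic hook or a key-scoped variant, but not both at the same level of specialization. A sketch under assumed flash naming conventions (the class below is hypothetical):

class MyInputTransform(InputTransform):
    # Generic hook: receives the whole sample dict.
    def per_sample_transform(self) -> Callable:
        return torch.nn.Identity()

    # A key-scoped hook such as `input_per_sample_transform` would be wrapped
    # as ApplyToKeys("input", fn); overriding it *and* per_sample_transform at
    # the same level of specialization raises the MisconfigurationException above.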
 "per_sample_transform":
 nn.Sequential(
     ApplyToKeys(
         DataKeys.INPUT,
         nn.Sequential(
             torchvision.transforms.ToTensor(),
             Kg.Resize((196, 196)),
             # SPATIAL
             Ka.RandomHorizontalFlip(p=0.25),
             Ka.RandomRotation(degrees=90.0, p=0.25),
             Ka.RandomAffine(degrees=1 * 5.0,
                             shear=1 / 5,
                             translate=1 / 20,
                             p=0.25),
             Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
             # PIXEL-LEVEL
             Ka.ColorJitter(brightness=1 / 30, p=0.25),  # brightness
             Ka.ColorJitter(saturation=1 / 30, p=0.25),  # saturation
             Ka.ColorJitter(contrast=1 / 30, p=0.25),  # contrast
             Ka.ColorJitter(hue=1 / 30, p=0.25),  # hue
             Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1,
                                 angle=1,
                                 direction=1.0,
                                 p=0.25),
             Ka.RandomErasing(scale=(1 / 100, 1 / 50),
                              ratio=(1 / 20, 1),
                              p=0.25),
         ),
     ),
     ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
 ),
Code example #16
    def per_sample_transform(self) -> Callable:
        return ApplyToKeys(
            [DataKeys.INPUT, DataKeys.TARGET],
            KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
        )