class TestApplyToKeys:
    @pytest.mark.parametrize(
        "sample, keys, expected",
        [
            ({DataKeys.INPUT: "test"}, DataKeys.INPUT, "test"),
            (
                {DataKeys.INPUT: "test_a", DataKeys.TARGET: "test_b"},
                [DataKeys.INPUT, DataKeys.TARGET],
                ["test_a", "test_b"],
            ),
            ({"input": "test"}, "input", "test"),
            ({"input": "test_a", "target": "test_b"}, ["input", "target"], ["test_a", "test_b"]),
            ({"input": "test_a", "target": "test_b", "extra": "..."}, ["input", "target"], ["test_a", "test_b"]),
            ({"input": "test_a", "target": "test_b"}, ["input", "target", "extra"], ["test_a", "test_b"]),
            ({"target": "..."}, "input", None),
        ],
    )
    def test_forward(self, sample, keys, expected):
        transform = Mock(return_value=["out"] * len(keys))
        ApplyToKeys(keys, transform)(sample)
        if expected is not None:
            transform.assert_called_once_with(expected)
        else:
            transform.assert_not_called()

    @pytest.mark.parametrize(
        "transform, expected",
        [
            (
                ApplyToKeys(DataKeys.INPUT, torch.nn.ReLU()),
                "ApplyToKeys(keys=<DataKeys.INPUT: 'input'>, transform=ReLU())",
            ),
            (
                ApplyToKeys([DataKeys.INPUT, DataKeys.TARGET], torch.nn.ReLU()),
                "ApplyToKeys(keys=[<DataKeys.INPUT: 'input'>, "
                "<DataKeys.TARGET: 'target'>], transform=ReLU())",
            ),
            (ApplyToKeys("input", torch.nn.ReLU()), "ApplyToKeys(keys='input', transform=ReLU())"),
            (
                ApplyToKeys(["input", "target"], torch.nn.ReLU()),
                "ApplyToKeys(keys=['input', 'target'], transform=ReLU())",
            ),
        ],
    )
    def test_repr(self, transform, expected):
        assert repr(transform) == expected
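# A minimal standalone sketch of the behaviour the tests above exercise:
# ApplyToKeys applies the wrapped transform only to the listed key(s) of a
# dict-style sample and passes every other entry through untouched. The
# import path assumes flash.core.data.transforms; adjust to your Flash version.
import torch
from flash.core.data.transforms import ApplyToKeys

sample = {"input": torch.tensor([-1.0, 2.0]), "target": 3}
out = ApplyToKeys("input", torch.nn.ReLU())(sample)
# out["input"] == tensor([0., 2.]), while out["target"] is still 3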
def per_batch_transform_on_device(self) -> Callable:
    return ApplyToKeys(
        "video",
        K.VideoSequential(
            K.Normalize(self.mean, self.std),
            data_format=self.data_format,
            same_on_frame=self.same_on_frame,
        ),
    )
def per_sample_transform(self) -> Callable:
    per_sample_transform = [CenterCrop(self.image_size)]
    # `normalize` is assumed to be a module-level helper defined alongside this
    # method (e.g. scaling raw video frames into [0, 1]).
    return ApplyToKeys(
        "video",
        Compose([UniformTemporalSubsample(self.temporal_sub_sample), normalize] + per_sample_transform),
    )
def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
    default_transforms = default_transforms_fn(*args, **kwargs)
    if not default_transforms:
        return default_transforms
    return {hook: ApplyToKeys(keys, transform) for hook, transform in default_transforms.items()}
def test_from_filepaths_splits(tmpdir):
    tmpdir = Path(tmpdir)

    B, _, H, W = 2, 3, 224, 224
    img_size: Tuple[int, int] = (H, W)

    (tmpdir / "splits").mkdir()
    _rand_image(img_size).save(tmpdir / "s.png")

    num_samples: int = 10
    val_split: float = 0.3

    train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)]
    train_labels: List[int] = list(range(num_samples))

    assert len(train_filepaths) == len(train_labels)

    _to_tensor = {
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
    }

    def run(transform: Any = None):
        dm = ImageClassificationData.from_files(
            train_files=train_filepaths,
            train_targets=train_labels,
            train_transform=transform,
            val_transform=transform,
            batch_size=B,
            num_workers=0,
            val_split=val_split,
            image_size=img_size,
        )
        data = next(iter(dm.train_dataloader()))
        imgs, labels = data["input"], data["target"]
        assert imgs.shape == (B, 3, H, W)
        assert labels.shape == (B,)

    run(_to_tensor)
def train_per_sample_transform(self) -> Callable:
    per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)]
    # As in `per_sample_transform`, `normalize` is assumed to be a module-level helper.
    return ApplyToKeys(
        "video",
        Compose([UniformTemporalSubsample(self.temporal_sub_sample), normalize] + per_sample_transform),
    )
def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """During training, we apply the default transforms with an additional ``RandomHorizontalFlip``."""
    return merge_transforms(
        default_transforms(image_size),
        {
            "post_tensor_transform": nn.Sequential(
                ApplyToKeys(
                    [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
                    KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)),
                ),
            ),
        },
    )
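# Why the flip is wrapped in KorniaParallelTransforms: the same randomly
# sampled flip decision must hit both the image and its target (e.g. a
# segmentation mask), or the pair falls out of alignment. A standalone sketch;
# the import paths and tensor shapes are illustrative assumptions.
import torch
import kornia as K
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, KorniaParallelTransforms

sample = {
    DefaultDataKeys.INPUT: torch.rand(3, 32, 32),
    DefaultDataKeys.TARGET: torch.randint(0, 2, (1, 32, 32)).float(),
}
flip = ApplyToKeys(
    [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
    KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)),
)
out = flip(sample)  # input and target are flipped together, or not at all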
def train_default_transforms(
    spectrogram_size: Tuple[int, int],
    time_mask_param: Optional[int],
    freq_mask_param: Optional[int],
) -> Dict[str, Callable]:
    """During training, we apply the default transforms with optional ``TimeMasking`` and ``FrequencyMasking``."""
    augs = []

    if time_mask_param is not None:
        augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)))

    if freq_mask_param is not None:
        augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)))

    if len(augs) > 0:
        return merge_transforms(default_transforms(spectrogram_size), {"post_tensor_transform": nn.Sequential(*augs)})
    return default_transforms(spectrogram_size)
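# A standalone sketch of the SpecAugment-style masking composed above:
# torchaudio's TimeMasking / FrequencyMasking zero out a random band of time
# steps / frequency bins in a spectrogram. Shapes and parameter values here
# are illustrative.
import torch
import torchaudio.transforms as TAudio

spec = torch.rand(1, 128, 400)  # (channels, freq bins, time steps)
masked = TAudio.TimeMasking(time_mask_param=80)(TAudio.FrequencyMasking(freq_mask_param=27)(spec))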
def predict_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """During prediction, we apply the default transforms only to ``DefaultDataKeys.INPUT``."""
    return {
        "post_tensor_transform": nn.Sequential(
            ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.geometry.Resize(image_size, interpolation="nearest"),
            ),
        ),
        "collate": kornia_collate,
    }
def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """The default transforms for image classification: resize the image, convert the image and target to a tensor,
    collate the batch, and apply normalization."""
    if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1":
        # Preferred path: every transform operates directly on tensors, so
        # normalization can run as a batched Kornia op on device.
        return {
            "to_tensor_transform": nn.Sequential(
                ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
                ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
            ),
            "post_tensor_transform": ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.geometry.Resize(image_size),
            ),
            "collate": kornia_collate,
            "per_batch_transform_on_device": ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.augmentation.Normalize(
                    torch.tensor([0.485, 0.456, 0.406]),
                    torch.tensor([0.229, 0.224, 0.225]),
                ),
            ),
        }
    else:
        return {
            "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)),
            "to_tensor_transform": nn.Sequential(
                ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
                ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
            ),
            "post_tensor_transform": ApplyToKeys(
                DefaultDataKeys.INPUT,
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ),
            "collate": kornia_collate,
        }
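# For orientation, the hooks in a dict like the one above run in this order:
# pre_tensor_transform -> to_tensor_transform -> post_tensor_transform ->
# collate -> per_batch_transform_on_device. A stage consumes the dict as in
# the tests elsewhere in this section (file paths are placeholders):
datamodule = ImageClassificationData.from_files(
    train_files=["img0.png", "img1.png"],  # placeholder paths
    train_targets=[0, 1],
    train_transform=default_transforms((196, 196)),
    batch_size=2,
)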
def default_transforms() -> Dict[str, Callable]:
    """The default transforms for object detection: convert the image and each target field to tensors, then
    collate the batch."""
    return {
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys("input", torchvision.transforms.ToTensor()),
            ApplyToKeys(
                "target",
                nn.Sequential(
                    ApplyToKeys("boxes", torch.as_tensor),
                    ApplyToKeys("labels", torch.as_tensor),
                    ApplyToKeys("image_id", torch.as_tensor),
                    ApplyToKeys("area", torch.as_tensor),
                    ApplyToKeys("iscrowd", torch.as_tensor),
                ),
            ),
        ),
        "collate": collate,
    }
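# Nesting works because each ApplyToKeys receives the value stored at its key
# (here the "target" sub-dict) and returns it transformed. A minimal sketch
# with only two of the fields:
import torch
from torch import nn
from flash.core.data.transforms import ApplyToKeys

to_tensor = ApplyToKeys(
    "target",
    nn.Sequential(
        ApplyToKeys("boxes", torch.as_tensor),
        ApplyToKeys("labels", torch.as_tensor),
    ),
)
sample = {"target": {"boxes": [[0, 0, 10, 10]], "labels": [1]}}
out = to_tensor(sample)  # out["target"]["boxes"] is now a Tensor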
def test_multicrop_input_transform():
    batch_size = 8
    total_crops = 6
    num_crops = [2, 4]
    size_crops = [160, 96]
    crop_scales = [[0.4, 1], [0.05, 0.4]]

    multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](
        total_crops, num_crops, size_crops, crop_scales
    )

    to_tensor_transform = ApplyToKeys(
        DefaultDataKeys.INPUT,
        multi_crop_transform,
    )
    preprocess = DefaultPreprocess(
        train_transform={
            "to_tensor_transform": to_tensor_transform,
            "collate": vissl_collate_fn,
        }
    )
    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(),
        preprocess=preprocess,
        batch_size=batch_size,
    )

    train_dataloader = datamodule._train_dataloader()
    batch = next(iter(train_dataloader))

    assert len(batch[DefaultDataKeys.INPUT]) == total_crops
    assert batch[DefaultDataKeys.INPUT][0].shape == (batch_size, 3, size_crops[0], size_crops[0])
    assert batch[DefaultDataKeys.INPUT][-1].shape == (batch_size, 3, size_crops[-1], size_crops[-1])
    assert list(batch[DefaultDataKeys.TARGET].shape) == [batch_size]
def target_per_sample_transform(self) -> Callable:
    return ApplyToKeys("target_boxes", torch.as_tensor)
def __resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
    from flash.core.data.data_pipeline import DataPipeline

    transforms_out = {}
    stage = _STAGES_PREFIX[running_stage]

    # Iterate over all transform hook names.
    for transform_name in InputTransformPlacement:

        transforms = {}
        transform_name = transform_name.value

        # Iterate over all prefixes.
        for key in ApplyToKeyPrefix:

            # Get the resolved hook name based on the current stage.
            resolved_name = DataPipeline._resolve_function_hierarchy(
                transform_name, self, running_stage, InputTransform
            )
            # Check whether the hook name is specialized.
            is_specialized_name = resolved_name.startswith(stage)

            # Get the resolved "apply to key" hook name for the current stage.
            resolved_apply_to_key_name = DataPipeline._resolve_function_hierarchy(
                f"{key}_{transform_name}", self, running_stage, InputTransform
            )
            # Check whether the resolved "apply to key" hook name is specialized.
            is_specialized_apply_to_key_name = resolved_apply_to_key_name.startswith(stage)

            # Check whether either hook is overridden by the user.
            resolved_name_overridden = DataPipeline._is_overridden(resolved_name, self, InputTransform)
            resolved_apply_to_key_name_overridden = DataPipeline._is_overridden(
                resolved_apply_to_key_name, self, InputTransform
            )

            if resolved_name_overridden and resolved_apply_to_key_name_overridden:
                # If both hooks are specialized, or neither is, raise an exception:
                # specialized hook names take priority, so exactly one must win.
                if not (is_specialized_name ^ is_specialized_apply_to_key_name):
                    raise MisconfigurationException(
                        f"Only one of {resolved_name} or {resolved_apply_to_key_name} can be overridden."
                    )

                method_name = resolved_name if is_specialized_name else resolved_apply_to_key_name
            else:
                method_name = resolved_apply_to_key_name if resolved_apply_to_key_name_overridden else resolved_name

            # Get the associated transform.
            try:
                fn = getattr(self, method_name)()
            except AttributeError as e:
                raise AttributeError(str(e) + ". Hint: Call super().__init__(...) after setting all attributes.")

            if not callable(fn):
                raise MisconfigurationException(f"The hook {method_name} should return a function.")

            # If the default hook is used, it returns the identity; skip it.
            if fn is self._identity:
                continue

            # Wrap an "apply to key" hook into `ApplyToKeys` with the associated key.
            if method_name == resolved_apply_to_key_name:
                fn = ApplyToKeys(key.value, fn)

            if method_name not in transforms:
                transforms[method_name] = fn

        # Store the transforms.
        if transforms:
            transforms = list(transforms.values())
            transforms_out[transform_name] = Compose(transforms) if len(transforms) > 1 else transforms[0]

    return transforms_out
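# A sketch of how these resolution rules play out in a user subclass (hook
# names follow the scheme resolved above; illustrative, not exhaustive):
class MyInputTransform(InputTransform):

    # "Apply to key" hook: automatically wrapped as ApplyToKeys("input", fn).
    def input_per_sample_transform(self) -> Callable:
        return torchvision.transforms.ToTensor()

    # Stage-specialized hook: overriding both this and the hook above is
    # allowed, because exactly one of the two is specialized for the train
    # stage, and the specialized name wins during training. Plain hooks
    # receive the whole sample dict, so keys are wrapped explicitly here.
    def train_per_sample_transform(self) -> Callable:
        return ApplyToKeys("input", torchvision.transforms.ToTensor())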
"per_sample_transform": nn.Sequential( ApplyToKeys( DataKeys.INPUT, nn.Sequential( torchvision.transforms.ToTensor(), Kg.Resize((196, 196)), # SPATIAL Ka.RandomHorizontalFlip(p=0.25), Ka.RandomRotation(degrees=90.0, p=0.25), Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25), Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25), # PIXEL-LEVEL Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast Ka.ColorJitter(hue=1 / 30, p=0.25), # hue Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25), Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25), ), ), ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ),
def per_sample_transform(self) -> Callable:
    return ApplyToKeys(
        [DataKeys.INPUT, DataKeys.TARGET],
        KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
    )
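# Design note: KorniaParallelTransforms replays the same transform (and any
# sampled parameters) on every listed key, keeping image and mask aligned,
# while interpolation="nearest" avoids blending integer class labels when
# the segmentation mask is resized.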