Ejemplo n.º 1
0
def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """Build the default validation transforms, keyed by hook name.

    Args:
        image_size: Size forwarded unchanged to ``K.geometry.Resize``.

    Returns:
        A dict with two entries:

        * ``"post_tensor_transform"``: an ``nn.Sequential`` that first runs a
          nearest-interpolation resize registered for the ``INPUT`` and
          ``TARGET`` keys, then ``prepare_target`` registered for ``TARGET``.
        * ``"per_batch_transform_on_device"``: ``K.enhance.Normalize(0., 255.)``
          registered for the ``INPUT`` key.
    """
    # One resize wrapped for both keys. NOTE(review): 'nearest' presumably
    # keeps target values from being blended by interpolation -- confirm.
    resize_both = ApplyToKeys(
        [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
        KorniaParallelTransforms(
            K.geometry.Resize(image_size, interpolation='nearest')),
    )
    finalize_target = ApplyToKeys(DefaultDataKeys.TARGET, prepare_target)

    # NOTE(review): the two positional arguments look like (mean, std), which
    # would amount to a division by 255 -- verify against the Kornia docs.
    normalize_input = ApplyToKeys(
        DefaultDataKeys.INPUT,
        K.enhance.Normalize(0., 255.),
    )

    return {
        "post_tensor_transform": nn.Sequential(resize_both, finalize_target),
        "per_batch_transform_on_device": normalize_input,
    }
Ejemplo n.º 2
0
 def to_tensor(self) -> Dict[str, Callable]:
     """Build the tensor-conversion stage, keyed by hook name.

     Returns:
         A dict with the single entry ``"to_tensor_transform"``: an
         ``nn.Sequential`` of two ``ApplyToKeys`` steps, one registered for
         the ``INPUT`` key and one for the ``TARGET`` key.
     """
     # NOTE(review): ApplyToKeys presumably runs its callables in the order
     # given -- confirm against its implementation.
     convert_input = ApplyToKeys(
         DefaultDataKeys.INPUT,
         torch.from_numpy,
         self.to_float,
     )
     # The target gets one extra step (format_targets) after the float cast.
     convert_target = ApplyToKeys(
         DefaultDataKeys.TARGET,
         torch.as_tensor,
         self.to_float,
         self.format_targets,
     )
     return {
         "to_tensor_transform": nn.Sequential(convert_input, convert_target),
     }
Ejemplo n.º 3
0
def default_train_transforms(
        image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """Build the default training transforms, keyed by hook name.

    Args:
        image_size: Size forwarded unchanged to ``K.geometry.Resize``.

    Returns:
        A dict with two entries:

        * ``"post_tensor_transform"``: an ``nn.Sequential`` that first runs a
          nearest-interpolation resize plus a random horizontal flip
          (``p=0.75``) registered for the ``INPUT`` and ``TARGET`` keys, then
          ``prepare_target`` registered for ``TARGET``.
        * ``"per_batch_transform_on_device"``: normalization followed by a
          colour jitter (``p=0.5``), registered for the ``INPUT`` key.
    """
    # NOTE(review): KorniaParallelTransforms presumably applies identical
    # random parameters to input and target so the flip stays aligned --
    # confirm against its implementation.
    geometric = KorniaParallelTransforms(
        K.geometry.Resize(image_size, interpolation='nearest'),
        K.augmentation.RandomHorizontalFlip(p=0.75),
    )
    shared_stage = ApplyToKeys(
        [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
        geometric,
    )
    target_stage = ApplyToKeys(DefaultDataKeys.TARGET, prepare_target)

    # Photometric steps are registered for the input only.
    on_device_stage = ApplyToKeys(
        DefaultDataKeys.INPUT,
        K.enhance.Normalize(0., 255.),
        K.augmentation.ColorJitter(0.4, p=0.5),
    )

    return {
        "post_tensor_transform": nn.Sequential(shared_stage, target_stage),
        "per_batch_transform_on_device": on_device_stage,
    }
Ejemplo n.º 4
0
def default_train_transforms(
        image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """Build the default training transforms, keyed by hook name.

    Two variants exist. The Kornia variant is returned when
    ``_KORNIA_AVAILABLE`` is truthy and the ``FLASH_TESTING`` environment
    variable is not ``"1"``; otherwise the torchvision (``T``) variant is
    returned.

    Args:
        image_size: Size forwarded unchanged to the resize transform of
            whichever variant is selected.

    Returns:
        A dict mapping hook names to transforms. Both variants contain
        ``"to_tensor_transform"``, ``"post_tensor_transform"`` and
        ``"collate"``. The Kornia variant adds
        ``"per_batch_transform_on_device"``; the torchvision variant adds
        ``"pre_tensor_transform"``.
    """
    # Both variants use the same tensor-conversion stage.
    tensor_stage = nn.Sequential(
        ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
        ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
    )

    # Torchvision variant: resize and flip run before tensor conversion,
    # normalization runs after it.
    if not _KORNIA_AVAILABLE or os.getenv("FLASH_TESTING", "0") == "1":
        resize_and_flip = ApplyToKeys(
            DefaultDataKeys.INPUT,
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
        )
        normalize = ApplyToKeys(
            DefaultDataKeys.INPUT,
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        )
        return {
            "pre_tensor_transform": resize_and_flip,
            "to_tensor_transform": tensor_stage,
            "post_tensor_transform": normalize,
            "collate": kornia_collate,
        }

    # Kornia variant: the better approach, since every transform is applied
    # to tensors directly.
    resize_and_flip = ApplyToKeys(
        DefaultDataKeys.INPUT,
        K.geometry.Resize(image_size),
        K.augmentation.RandomHorizontalFlip(),
    )
    normalize = ApplyToKeys(
        DefaultDataKeys.INPUT,
        K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]),
                                 torch.tensor([0.229, 0.224, 0.225])),
    )
    return {
        "to_tensor_transform": tensor_stage,
        "post_tensor_transform": resize_and_flip,
        "collate": kornia_collate,
        "per_batch_transform_on_device": normalize,
    }
Ejemplo n.º 5
0
def default_transforms() -> Dict[str, Callable]:
    """Build the default tensor-conversion transforms, keyed by hook name.

    Returns:
        A dict with the single entry ``"to_tensor_transform"``: an
        ``nn.Sequential`` that registers ``torchvision.transforms.ToTensor()``
        for the ``'input'`` key and, for the ``'target'`` key, a nested
        ``nn.Sequential`` registering ``torch.as_tensor`` for each of the
        target fields listed below.
    """
    # Order matters only for the positional indices inside nn.Sequential;
    # it is kept identical to the hand-written sequence.
    target_fields = ('boxes', 'labels', 'image_id', 'area', 'iscrowd')
    target_stage = nn.Sequential(
        *(ApplyToKeys(field, torch.as_tensor) for field in target_fields))

    return {
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys('input', torchvision.transforms.ToTensor()),
            ApplyToKeys('target', target_stage),
        ),
    }