コード例 #1
0
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        image_size: int = 256,
    ):
        """Initialise the object with image data sources for training and prediction.

        Validation and test transforms are rejected: a truthy value for either
        is handed to ``raise_not_supported`` before anything else happens.

        Args:
            train_transform: Forwarded unchanged to the parent initialiser.
            val_transform: Must be falsy; a truthy value triggers
                ``raise_not_supported("validation")``.
            test_transform: Must be falsy; a truthy value triggers
                ``raise_not_supported("test")``.
            predict_transform: Forwarded unchanged to the parent initialiser.
            image_size: An ``int`` is expanded to the pair
                ``(image_size, image_size)``; any other value is stored on
                ``self.image_size`` as given.
        """
        if val_transform:
            raise_not_supported("validation")
        if test_transform:
            raise_not_supported("test")

        if isinstance(image_size, int):
            image_size = (image_size, image_size)

        # Stored before the parent initialiser runs, matching the original order.
        self.image_size = image_size

        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            # Each key appears exactly once: a repeated key in a dict literal
            # silently discards the earlier value.
            data_sources={
                DefaultDataSources.FILES: ImagePathsDataSource(),
                DefaultDataSources.FOLDERS: ImagePathsDataSource(),
                DefaultDataSources.NUMPY: ImageNumpyDataSource(),
                DefaultDataSources.TENSORS: ImageTensorDataSource(),
            },
            default_data_source=DefaultDataSources.FILES,
        )
コード例 #2
0
    def from_folders(
        cls,
        train_folder: Optional[Union[str, pathlib.Path]] = None,
        predict_folder: Optional[Union[str, pathlib.Path]] = None,
        train_transform: Optional[Union[str, Dict]] = None,
        predict_transform: Optional[Union[str, Dict]] = None,
        preprocess: Optional[Preprocess] = None,
        **kwargs: Any,
    ) -> "StyleTransferData":
        """Build an instance from a training folder and/or a prediction folder.

        Args:
            train_folder: Passed to ``from_data_source`` as ``train_data``.
            predict_folder: Passed to ``from_data_source`` as ``predict_data``.
            train_transform: Used only when ``preprocess`` is falsy, to build
                the default ``cls.preprocess_cls``.
            predict_transform: Used only when ``preprocess`` is falsy, to build
                the default ``cls.preprocess_cls``.
            preprocess: Preprocess to use; when falsy, one is created from the
                two transform arguments.
            **kwargs: Forwarded to ``from_data_source``. The keys
                ``val_folder``, ``val_transform``, ``test_folder`` and
                ``test_transform`` are rejected via ``raise_not_supported``.

        Returns:
            Whatever ``cls.from_data_source`` returns for the ``FOLDERS``
            data source.
        """
        # The mere presence of a key is enough to reject the call, whatever
        # value it carries.
        if "val_folder" in kwargs or "val_transform" in kwargs:
            raise_not_supported("validation")

        if "test_folder" in kwargs or "test_transform" in kwargs:
            raise_not_supported("test")

        if not preprocess:
            preprocess = cls.preprocess_cls(
                train_transform=train_transform,
                predict_transform=predict_transform,
            )

        return cls.from_data_source(
            DefaultDataSources.FOLDERS,
            train_data=train_folder,
            predict_data=predict_folder,
            preprocess=preprocess,
            **kwargs,
        )
コード例 #3
0
 def test_step(self, batch: Any, batch_idx: int) -> NoReturn:
     """Reject the test step.

     Forwards ``"test"`` to ``raise_not_supported``; ``batch`` and
     ``batch_idx`` are accepted but not used. The ``NoReturn`` annotation
     indicates the call is expected to raise rather than return -- confirm
     against the definition of ``raise_not_supported``.
     """
     raise_not_supported("test")
コード例 #4
0
 def validation_step(self, batch: Any, batch_idx: int) -> NoReturn:
     """Reject the validation step.

     Forwards ``"validation"`` to ``raise_not_supported``; ``batch`` and
     ``batch_idx`` are accepted but not used. The ``NoReturn`` annotation
     indicates the call is expected to raise rather than return -- confirm
     against the definition of ``raise_not_supported``.
     """
     raise_not_supported("validation")