예제 #1
0
    def __init__(
        self,
        preprocess: "Preprocess",
        pre_tensor_transform: Optional[Callable],
        to_tensor_transform: Optional[Callable],
        post_tensor_transform: Callable,
        stage: RunningStage,
        assert_contains_tensor: bool = False,
    ):
        """Attach the pre/to/post tensor hooks of ``preprocess`` for a single running stage.

        Args:
            preprocess: Owning ``Preprocess``. Its ``callbacks`` feed the ``ControlFlow`` stored as
                ``self.callback``, and it is passed to every context object created below.
            pre_tensor_transform: Hook stored as ``self.pre_tensor_transform`` after ``convert_to_modules``.
            to_tensor_transform: Hook stored as ``self.to_tensor_transform`` after ``convert_to_modules``.
            post_tensor_transform: Hook stored as ``self.post_tensor_transform`` after ``convert_to_modules``.
            stage: Running stage recorded on ``self.stage`` and handed to the stage context.
            assert_contains_tensor: Flag stored unchanged on ``self.assert_contains_tensor``; it is not
                read inside this constructor.
        """
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)

        # Every hook is passed through ``convert_to_modules`` before being attached to ``self``.
        self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
        self.to_tensor_transform = convert_to_modules(to_tensor_transform)
        self.post_tensor_transform = convert_to_modules(post_tensor_transform)

        self.stage = stage
        self.assert_contains_tensor = assert_contains_tensor

        # Context objects are created once, here; the stage context is built with ``reset=False``.
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False)
        self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", preprocess)
        self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess)
        self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess)
예제 #2
0
def merge_transforms(
    base_transforms: Dict[str, Callable],
    additional_transforms: Dict[str, Callable],
) -> Dict[str, Callable]:
    """Utility function to merge two transform dictionaries hook by hook.

    For every hook name in ``_PREPROCESS_FUNCS``:

    * if both dictionaries define it, the result chains them in an ``nn.Sequential`` so the
      ``additional_transforms`` entry is called after the ``base_transforms`` entry;
    * if only one dictionary defines it, that entry is used unchanged;
    * if neither defines it, the hook is absent from the result.

    Keys that are not listed in ``_PREPROCESS_FUNCS`` are not copied into the result.

    Args:
        base_transforms: The base transforms dictionary.
        additional_transforms: The dictionary of additional transforms to be appended to the ``base_transforms``.

    Returns:
        A new dictionary of transforms. Neither input dictionary is modified.
    """
    transforms = {}
    for hook in _PREPROCESS_FUNCS:
        in_base = hook in base_transforms
        in_additional = hook in additional_transforms
        if in_base and in_additional:
            # ``nn.Sequential`` only accepts ``nn.Module`` children and runs them in order,
            # so both callables are converted first and the base one is placed first.
            transforms[hook] = nn.Sequential(
                convert_to_modules(base_transforms[hook]),
                convert_to_modules(additional_transforms[hook]),
            )
        elif in_base:
            transforms[hook] = base_transforms[hook]
        elif in_additional:
            transforms[hook] = additional_transforms[hook]
    return transforms
예제 #3
0
    def __init__(
        self,
        preprocess: "Preprocess",
        collate_fn: Callable,
        per_sample_transform: Union[Callable, _Sequential],
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        """Attach the collate and per-sample / per-batch hooks of ``preprocess`` for one stage.

        Args:
            preprocess: Owning ``Preprocess``. Its ``callbacks`` feed ``self.callback`` and it is passed
                to every context object created below.
            collate_fn: Hook stored as ``self.collate_fn`` after ``convert_to_modules``.
            per_sample_transform: Hook stored as ``self.per_sample_transform`` after ``convert_to_modules``.
            per_batch_transform: Hook stored as ``self.per_batch_transform`` after ``convert_to_modules``.
            stage: Running stage recorded on ``self.stage`` and handed to the stage context.
            apply_per_sample_transform: Flag stored unchanged; not read inside this constructor.
            on_device: When truthy, the per-sample and per-batch context names get an
                ``"_on_device"`` suffix.
        """
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)

        self.collate_fn = convert_to_modules(collate_fn)
        self.per_sample_transform = convert_to_modules(per_sample_transform)
        self.per_batch_transform = convert_to_modules(per_batch_transform)

        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

        # The same hook has a distinct name when it runs on device, e.g. "per_batch_transform_on_device".
        suffix = "_on_device" if self.on_device else ""
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
        self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{suffix}", preprocess)
        self._collate_context = CurrentFuncContext("collate", preprocess)
        self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{suffix}", preprocess)
예제 #4
0
    def __init__(
        self,
        train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
        val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
        test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
        predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
        data_sources: Optional[Dict[str, "DataSource"]] = None,
        deserializer: Optional["Deserializer"] = None,
        default_data_source: Optional[str] = None,
    ):
        """Set up the per-stage transforms, the data sources and the deserializer.

        Args:
            train_transform: Transform(s) for ``RunningStage.TRAINING``. When falsy, the result of
                ``self._resolve_transforms(RunningStage.TRAINING)`` is used instead.
            val_transform: Same as ``train_transform``, for ``RunningStage.VALIDATING``.
            test_transform: Same as ``train_transform``, for ``RunningStage.TESTING``.
            predict_transform: Same as ``train_transform``, for ``RunningStage.PREDICTING``.
            data_sources: ``DataSource`` objects keyed by string. ``None`` is treated as an empty dict.
                The dict is copied before a ``DatasetDataSource`` is registered under
                ``DefaultDataSources.DATASETS`` (only if that key is missing), so the caller's dict is
                never modified.
            deserializer: Stored as ``self._deserializer``.
            default_data_source: Stored as ``self._default_data_source``.
        """
        super().__init__()

        # Resolve the default transforms. Note that ``or`` also replaces falsy values such as an
        # empty dict, not just ``None``.
        train_transform = train_transform or self._resolve_transforms(RunningStage.TRAINING)
        val_transform = val_transform or self._resolve_transforms(RunningStage.VALIDATING)
        test_transform = test_transform or self._resolve_transforms(RunningStage.TESTING)
        predict_transform = predict_transform or self._resolve_transforms(RunningStage.PREDICTING)

        # Used to keep track of provided transforms. They are initialized before the
        # ``_check_transforms`` calls below, which presumably fill them in — confirm.
        self._train_collate_in_worker_from_transform: Optional[bool] = None
        self._val_collate_in_worker_from_transform: Optional[bool] = None
        self._predict_collate_in_worker_from_transform: Optional[bool] = None
        self._test_collate_in_worker_from_transform: Optional[bool] = None

        # Store the checked transforms before conversion to modules (public attributes) ...
        self.train_transform = self._check_transforms(train_transform, RunningStage.TRAINING)
        self.val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING)
        self.test_transform = self._check_transforms(test_transform, RunningStage.TESTING)
        self.predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING)

        # ... and their module-converted counterparts (private attributes).
        self._train_transform = convert_to_modules(self.train_transform)
        self._val_transform = convert_to_modules(self.val_transform)
        self._test_transform = convert_to_modules(self.test_transform)
        self._predict_transform = convert_to_modules(self.predict_transform)

        # The documented default is ``None``, which previously crashed the ``in`` test below with a
        # TypeError. Copying also keeps the default registration from mutating the caller's dict.
        data_sources = dict(data_sources or {})
        if DefaultDataSources.DATASETS not in data_sources:
            data_sources[DefaultDataSources.DATASETS] = DatasetDataSource()

        self._data_sources = data_sources
        self._deserializer = deserializer
        self._default_data_source = default_data_source
        self._callbacks: List[FlashCallback] = []
        self._default_collate: Callable = default_collate
예제 #5
0
    def __init__(
        self,
        deserializer: "Deserializer",
        preprocess: "Preprocess",
        pre_tensor_transform: Callable,
        to_tensor_transform: Callable,
    ):
        """Attach the deserializer and the pre/to tensor hooks used for the predicting stage.

        Args:
            deserializer: Stored as ``self.deserializer`` after ``convert_to_modules``.
            preprocess: Owning ``Preprocess``. Its ``callbacks`` feed ``self.callback`` and it is passed
                to every context object created below.
            pre_tensor_transform: Hook stored as ``self.pre_tensor_transform`` after ``convert_to_modules``.
            to_tensor_transform: Hook stored as ``self.to_tensor_transform`` after ``convert_to_modules``.
        """
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)

        self.deserializer = convert_to_modules(deserializer)
        self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
        self.to_tensor_transform = convert_to_modules(to_tensor_transform)

        # The stage is fixed to PREDICTING here (no ``stage`` argument), with ``reset=False``.
        self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, preprocess, reset=False)
        self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", preprocess)
        self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess)
예제 #6
0
 def __init__(
     self,
     uncollate_fn: Callable,
     per_batch_transform: Callable,
     per_sample_transform: Callable,
     serializer: Optional[Callable],
     save_fn: Optional[Callable] = None,
     save_per_sample: bool = False,
 ):
     """Attach the post-processing hooks, each passed through ``convert_to_modules``.

     Args:
         uncollate_fn: Stored as ``self.uncollate_fn``.
         per_batch_transform: Stored as ``self.per_batch_transform``.
         per_sample_transform: Stored as ``self.per_sample_transform``.
         serializer: Optional hook, stored as ``self.serializer``.
         save_fn: Optional hook (default ``None``), stored as ``self.save_fn``.
         save_per_sample: Boolean flag (default ``False``), stored as ``self.save_per_sample``.
     """
     super().__init__()

     self.uncollate_fn = convert_to_modules(uncollate_fn)
     self.per_batch_transform = convert_to_modules(per_batch_transform)
     self.per_sample_transform = convert_to_modules(per_sample_transform)
     self.serializer = convert_to_modules(serializer)
     self.save_fn = convert_to_modules(save_fn)

     # NOTE(review): ``save_per_sample`` is a plain bool, yet it goes through ``convert_to_modules``
     # like the callables above. Presumably that helper returns non-callables unchanged — confirm,
     # and consider assigning the flag directly.
     self.save_per_sample = convert_to_modules(save_per_sample)
예제 #7
0
     },
     {
         "to_tensor_transform": _MOCK_TRANSFORM,
         "post_tensor_transform": _MOCK_TRANSFORM
     },
 ),
 (
     {
         "to_tensor_transform": _MOCK_TRANSFORM
     },
     {
         "to_tensor_transform": _MOCK_TRANSFORM
     },
     {
         "to_tensor_transform": nn.Sequential(
             convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM)
         )
     },
 ),
 (
     {
         "to_tensor_transform": _MOCK_TRANSFORM
     },
     {
         "to_tensor_transform": _MOCK_TRANSFORM,
         "post_tensor_transform": _MOCK_TRANSFORM
     },
     {
         "to_tensor_transform": nn.Sequential(
             convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM)
         ),
예제 #8
0
 def __init__(self, serializer: "Serializer"):
     """Hold ``serializer`` after passing it through ``convert_to_modules``.

     Args:
         serializer: The serializer to wrap; stored as ``self.serializer``.
     """
     super().__init__()
     wrapped = convert_to_modules(serializer)
     self.serializer = wrapped
예제 #9
0
 def __init__(self, *args):
     """Forward all positional arguments to the parent constructor after ``convert_to_modules``.

     Arguments are converted one by one, in the given order, before the parent constructor runs.
     """
     super().__init__(*map(convert_to_modules, args))
예제 #10
0
 def __init__(self, keys: Union[str, Sequence[str]], *args):
     """Forward converted positional arguments to the parent and remember the target ``keys``.

     Args:
         keys: One key or a sequence of keys. A single ``str`` is wrapped in a one-element list;
             any other value is stored as given.
         *args: Each passed through ``convert_to_modules`` and forwarded to the parent constructor.
     """
     super().__init__(*map(convert_to_modules, args))
     # A bare string is iterable, so wrap it to avoid it being treated as a sequence of characters.
     self.keys = [keys] if isinstance(keys, str) else keys
예제 #11
0
     },
     {
         "to_tensor_transform": _MOCK_TRANSFORM,
         "post_tensor_transform": _MOCK_TRANSFORM
     },
 ),
 (
     {
         "to_tensor_transform": _MOCK_TRANSFORM
     },
     {
         "to_tensor_transform": _MOCK_TRANSFORM
     },
     {
         "to_tensor_transform":
         nn.Sequential(convert_to_modules(_MOCK_TRANSFORM),
                       convert_to_modules(_MOCK_TRANSFORM))
     },
 ),
 (
     {
         "to_tensor_transform": _MOCK_TRANSFORM
     },
     {
         "to_tensor_transform": _MOCK_TRANSFORM,
         "post_tensor_transform": _MOCK_TRANSFORM
     },
     {
         "to_tensor_transform":
         nn.Sequential(convert_to_modules(_MOCK_TRANSFORM),
                       convert_to_modules(_MOCK_TRANSFORM)),