Esempio n. 1
0
    def __init__(
        self,
        preprocess: 'Preprocess',
        collate_fn: Callable,
        per_sample_transform: Union[Callable, _Sequential],
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        """Store the collate, per-sample and per-batch hooks of ``preprocess`` for ``stage``.

        Args:
            preprocess: Owning ``Preprocess``; its ``callbacks`` are wrapped in a
                ``ControlFlow`` and it is passed to every context created here.
            collate_fn: Collate hook, stored via ``convert_to_modules``.
            per_sample_transform: Per-sample hook, stored via ``convert_to_modules``.
            per_batch_transform: Per-batch hook, stored via ``convert_to_modules``.
            stage: Running stage, stored and used for the stage context.
            apply_per_sample_transform: Flag stored verbatim; consumed outside
                this constructor.
            on_device: When true, the per-sample and per-batch function contexts
                are named ``<hook>_on_device`` instead of ``<hook>``.
        """
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        self.collate_fn = convert_to_modules(collate_fn)
        self.per_sample_transform = convert_to_modules(per_sample_transform)
        self.per_batch_transform = convert_to_modules(per_batch_transform)
        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

        # Suffix selecting the on-device variant of the transform hook names.
        # The collate context below is never suffixed.
        extension = "_on_device" if self.on_device else ""
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
        self._per_sample_transform_context = CurrentFuncContext(
            f"per_sample_transform{extension}", preprocess)
        self._collate_context = CurrentFuncContext("collate", preprocess)
        self._per_batch_transform_context = CurrentFuncContext(
            f"per_batch_transform{extension}", preprocess)
Esempio n. 2
0
    def __init__(
        self,
        preprocess: 'Preprocess',
        pre_tensor_transform: Callable,
        to_tensor_transform: Callable,
        post_tensor_transform: Callable,
        stage: RunningStage,
        assert_contains_tensor: bool = False,
    ):
        """Store the pre-tensor, to-tensor and post-tensor hooks of ``preprocess``.

        Args:
            preprocess: Owning ``Preprocess``; its ``callbacks`` are wrapped in a
                ``ControlFlow`` and it is passed to every context created here.
            pre_tensor_transform: Stored via ``convert_to_modules``.
            to_tensor_transform: Stored via ``convert_to_modules``.
            post_tensor_transform: Stored via ``convert_to_modules``.
            stage: Running stage, stored and used for the stage context.
            assert_contains_tensor: Flag stored verbatim; consumed outside this
                constructor.
        """
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)

        self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
        self.to_tensor_transform = convert_to_modules(to_tensor_transform)
        self.post_tensor_transform = convert_to_modules(post_tensor_transform)
        self.stage = stage
        self.assert_contains_tensor = assert_contains_tensor

        # Built with reset=False. What that flag controls is defined by
        # CurrentRunningStageContext (NOTE(review): confirm its semantics there).
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False)

        # One function context per hook, stored as `_<hook>_context` and named
        # after the hook itself, in pre -> to -> post order.
        for hook_name in ("pre_tensor_transform", "to_tensor_transform", "post_tensor_transform"):
            setattr(self, f"_{hook_name}_context", CurrentFuncContext(hook_name, preprocess))
Esempio n. 3
0
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
    ):
        """Validate and store the transforms for each running stage.

        Each argument goes through ``_check_transforms`` for its stage. The result
        is kept on the underscore attribute (e.g. ``_train_transform``), and its
        ``convert_to_modules`` form on the public one (e.g. ``train_transform``).

        Args:
            train_transform: Checked against ``RunningStage.TRAINING``.
            val_transform: Checked against ``RunningStage.VALIDATING``.
            test_transform: Checked against ``RunningStage.TESTING``.
            predict_transform: Checked against ``RunningStage.PREDICTING``.
        """
        super().__init__()

        # Per-stage collate-in-worker flags, unset (None) here.
        # NOTE(review): declared before `_check_transforms` runs, presumably
        # because that method may fill them in -- confirm.
        self._train_collate_in_worker_from_transform: Optional[bool] = None
        self._val_collate_in_worker_from_transform: Optional[bool] = None
        self._predict_collate_in_worker_from_transform: Optional[bool] = None
        self._test_collate_in_worker_from_transform: Optional[bool] = None

        # Validated transforms, kept in their unconverted form.
        self._train_transform = self._check_transforms(train_transform, RunningStage.TRAINING)
        self._val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING)
        self._test_transform = self._check_transforms(test_transform, RunningStage.TESTING)
        self._predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING)

        # Module-converted counterparts, exposed on the public attributes.
        self.train_transform = convert_to_modules(self._train_transform)
        self.val_transform = convert_to_modules(self._val_transform)
        self.test_transform = convert_to_modules(self._test_transform)
        self.predict_transform = convert_to_modules(self._predict_transform)

        self._callbacks: List[FlashCallback] = []
Esempio n. 4
0
    def __init__(
        self,
        pre_tensor_transform: Callable,
        to_tensor_transform: Callable,
        post_tensor_transform: Callable,
        assert_contains_tensor: bool = False,
    ):
        """Store the three tensor-conversion hooks and the tensor-check flag.

        Args:
            pre_tensor_transform: Stored via ``convert_to_modules``.
            to_tensor_transform: Stored via ``convert_to_modules``.
            post_tensor_transform: Stored via ``convert_to_modules``.
            assert_contains_tensor: Flag stored verbatim; consumed outside this
                constructor.
        """
        super().__init__()

        converted = map(
            convert_to_modules,
            (pre_tensor_transform, to_tensor_transform, post_tensor_transform),
        )
        self.pre_tensor_transform, self.to_tensor_transform, self.post_tensor_transform = converted
        self.assert_contains_tensor = assert_contains_tensor
Esempio n. 5
0
 def __init__(
     self,
     train_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
     val_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
     test_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
     predict_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
 ):
     """Store each stage's transform after passing it through ``convert_to_modules``.

     Args:
         train_transform: Stored as ``train_transform``.
         val_transform: Stored as ``val_transform``.
         test_transform: Stored as ``test_transform``.
         predict_transform: Stored as ``predict_transform``.
     """
     super().__init__()
     per_stage = {
         "train_transform": train_transform,
         "val_transform": val_transform,
         "test_transform": test_transform,
         "predict_transform": predict_transform,
     }
     # Dict insertion order fixes the assignment order: train, val, test, predict.
     for attr_name, transform in per_stage.items():
         setattr(self, attr_name, convert_to_modules(transform))
Esempio n. 6
0
 def __init__(self,
              uncollate_fn: Callable,
              per_batch_transform: Callable,
              per_sample_transform: Callable,
              save_fn: Optional[Callable] = None,
              save_per_sample: bool = False):
     """Store the uncollate, per-batch and per-sample hooks plus the save settings.

     Args:
         uncollate_fn: Stored via ``convert_to_modules``.
         per_batch_transform: Stored via ``convert_to_modules``.
         per_sample_transform: Stored via ``convert_to_modules``.
         save_fn: Optional save hook; passed through ``convert_to_modules``
             as given, including ``None``.
         save_per_sample: Plain boolean flag, stored verbatim.
     """
     super().__init__()
     self.uncollate_fn = convert_to_modules(uncollate_fn)
     self.per_batch_transform = convert_to_modules(per_batch_transform)
     self.per_sample_transform = convert_to_modules(per_sample_transform)
     self.save_fn = convert_to_modules(save_fn)
     # A bool is configuration, not a transform: store it directly instead of
     # routing it through the transform-to-module converter.
     self.save_per_sample = save_per_sample
Esempio n. 7
0
 def __init__(
     self,
     collate_fn: Callable,
     per_sample_transform: Union[Callable, _Sequential],
     per_batch_transform: Callable,
     stage: Optional[RunningStage] = None,
     apply_per_sample_transform: bool = True,
 ):
     """Store the collate, per-sample and per-batch hooks for an optional stage.

     Args:
         collate_fn: Stored via ``convert_to_modules``.
         per_sample_transform: Stored via ``convert_to_modules``.
         per_batch_transform: Stored via ``convert_to_modules``.
         stage: Running stage stored verbatim; may be ``None``.
         apply_per_sample_transform: Flag stored verbatim; consumed outside
             this constructor.
     """
     super().__init__()
     hooks = (collate_fn, per_sample_transform, per_batch_transform)
     self.collate_fn, self.per_sample_transform, self.per_batch_transform = (
         convert_to_modules(hook) for hook in hooks
     )
     self.apply_per_sample_transform = apply_per_sample_transform
     self.stage = stage
    def __init__(
        self,
        train_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
        val_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
        test_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
        predict_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
    ):
        """Store the per-stage transforms and default the mutual-check flag.

        Args:
            train_transform: Stored as ``train_transform`` via ``convert_to_modules``.
            val_transform: Stored as ``val_transform`` via ``convert_to_modules``.
            test_transform: Stored as ``test_transform`` via ``convert_to_modules``.
            predict_transform: Stored as ``predict_transform`` via ``convert_to_modules``.
        """
        super().__init__()
        attr_names = ("train_transform", "val_transform", "test_transform", "predict_transform")
        transforms = (train_transform, val_transform, test_transform, predict_transform)
        for attr_name, transform in zip(attr_names, transforms):
            setattr(self, attr_name, convert_to_modules(transform))

        # Respect a `_skip_mutual_check` that is already reachable on the
        # instance (e.g. defined on a subclass); only otherwise set it to False.
        if not hasattr(self, "_skip_mutual_check"):
            self._skip_mutual_check = False
Esempio n. 9
0
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_sources: Optional[Dict[str, DataSource]] = None,
        default_data_source: Optional[str] = None,
    ):
        """Resolve, validate and store the per-stage transforms and data sources.

        A falsy transform argument (e.g. ``None`` or an empty dict) is replaced by
        the matching ``default_*_transforms`` attribute. Each resolved transform is
        validated with ``_check_transforms`` and kept on the public attribute
        (e.g. ``train_transform``). Its ``convert_to_modules`` form is kept on the
        underscore attribute (e.g. ``_train_transform``).

        Args:
            train_transform: Checked against ``RunningStage.TRAINING``.
            val_transform: Checked against ``RunningStage.VALIDATING``.
            test_transform: Checked against ``RunningStage.TESTING``.
            predict_transform: Checked against ``RunningStage.PREDICTING``.
            data_sources: Stored verbatim on ``_data_sources``.
            default_data_source: Stored verbatim on ``_default_data_source``.
        """
        super().__init__()

        # Fall back to the class-provided default for any falsy transform argument.
        train_transform = train_transform if train_transform else self.default_train_transforms
        val_transform = val_transform if val_transform else self.default_val_transforms
        test_transform = test_transform if test_transform else self.default_test_transforms
        predict_transform = (
            predict_transform if predict_transform else self.default_predict_transforms
        )

        # Per-stage collate-in-worker flags, unset (None) here.
        # NOTE(review): declared before `_check_transforms` runs, presumably
        # because that method may fill them in -- confirm.
        self._train_collate_in_worker_from_transform: Optional[bool] = None
        self._val_collate_in_worker_from_transform: Optional[bool] = None
        self._predict_collate_in_worker_from_transform: Optional[bool] = None
        self._test_collate_in_worker_from_transform: Optional[bool] = None

        # Public attributes hold the validated transforms before module conversion...
        self.train_transform = self._check_transforms(train_transform, RunningStage.TRAINING)
        self.val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING)
        self.test_transform = self._check_transforms(test_transform, RunningStage.TESTING)
        self.predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING)

        # ...and the underscore attributes hold their convert_to_modules form.
        self._train_transform = convert_to_modules(self.train_transform)
        self._val_transform = convert_to_modules(self.val_transform)
        self._test_transform = convert_to_modules(self.test_transform)
        self._predict_transform = convert_to_modules(self.predict_transform)

        self._data_sources = data_sources
        self._default_data_source = default_data_source

        self._callbacks: List[FlashCallback] = []
        self._default_collate: Callable = default_collate
Esempio n. 10
0
 def __init__(self, *args):
     """Forward every positional argument, in order, to the parent constructor
     after passing each one through ``convert_to_modules``."""
     super().__init__(*map(convert_to_modules, args))
Esempio n. 11
0
 def __init__(self, keys: Union[str, Sequence[str]], *args):
     """Forward converted ``args`` to the parent constructor and record ``keys``.

     Args:
         keys: A single key or a sequence of keys. A lone string is wrapped in a
             one-element list; anything else is stored as given.
         *args: Each passed through ``convert_to_modules``, then forwarded
             positionally to the parent constructor.
     """
     converted_args = [convert_to_modules(arg) for arg in args]
     super().__init__(*converted_args)
     self.keys = [keys] if isinstance(keys, str) else keys