def __init__(
    self,
    preprocess: 'Preprocess',
    collate_fn: Callable,
    per_sample_transform: Union[Callable, _Sequential],
    per_batch_transform: Callable,
    stage: RunningStage,
    apply_per_sample_transform: bool = True,
    on_device: bool = False,
):
    """Set up the collation pipeline for a single running stage.

    Args:
        preprocess: owning ``Preprocess``. Its ``callbacks`` are wrapped in a
            ``ControlFlow``, and it is passed to every context object below.
        collate_fn: collation callable, stored module-converted.
        per_sample_transform: per-sample transform, stored module-converted.
        per_batch_transform: per-batch transform, stored module-converted.
        stage: running stage this pipeline serves.
        apply_per_sample_transform: flag, stored as-is.
        on_device: when true, the per-sample / per-batch hook names given to
            the function contexts carry an ``_on_device`` suffix.
    """
    super().__init__()

    self.preprocess = preprocess
    self.callback = ControlFlow(self.preprocess.callbacks)

    self.collate_fn = convert_to_modules(collate_fn)
    self.per_sample_transform = convert_to_modules(per_sample_transform)
    self.per_batch_transform = convert_to_modules(per_batch_transform)

    self.apply_per_sample_transform = apply_per_sample_transform
    self.stage = stage
    self.on_device = on_device

    # The on-device variant of this pipeline uses suffixed hook names.
    suffix = "_on_device" if self.on_device else ""
    per_sample_name = "per_sample_transform" + suffix
    per_batch_name = "per_batch_transform" + suffix

    # NOTE(review): these context objects presumably record the active stage /
    # hook name on ``preprocess`` while the matching step runs — confirm.
    self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
    self._per_sample_transform_context = CurrentFuncContext(per_sample_name, preprocess)
    self._collate_context = CurrentFuncContext("collate", preprocess)
    self._per_batch_transform_context = CurrentFuncContext(per_batch_name, preprocess)
def __init__(
    self,
    preprocess: 'Preprocess',
    pre_tensor_transform: Callable,
    to_tensor_transform: Callable,
    post_tensor_transform: Callable,
    stage: RunningStage,
    assert_contains_tensor: bool = False,
):
    """Set up the pre-tensor / to-tensor / post-tensor transform chain.

    Args:
        preprocess: owning ``Preprocess``. Its ``callbacks`` are wrapped in a
            ``ControlFlow``, and it is passed to each context object.
        pre_tensor_transform: stored module-converted.
        to_tensor_transform: stored module-converted.
        post_tensor_transform: stored module-converted.
        stage: running stage; also given to the stage context (with
            ``reset=False``).
        assert_contains_tensor: flag, stored as-is.
    """
    super().__init__()

    self.preprocess = preprocess
    self.callback = ControlFlow(self.preprocess.callbacks)

    self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
    self.to_tensor_transform = convert_to_modules(to_tensor_transform)
    self.post_tensor_transform = convert_to_modules(post_tensor_transform)

    self.stage = stage
    self.assert_contains_tensor = assert_contains_tensor

    self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False)

    # One function context per hook, stored as ``_<hook>_context`` and
    # labelled with the hook's own name.
    for hook_name in ("pre_tensor_transform", "to_tensor_transform", "post_tensor_transform"):
        setattr(self, f"_{hook_name}_context", CurrentFuncContext(hook_name, preprocess))
def __init__(
    self,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
):
    """Validate and store the per-stage transform dictionaries.

    Each stage's transform is passed through ``self._check_transforms``. The
    result is kept on ``_<stage>_transform``, and its module-converted form is
    exposed as ``<stage>_transform``.
    """
    super().__init__()

    # Bookkeeping about the provided transforms; all start as ``None``.
    # NOTE(review): presumably filled in later (e.g. by ``_check_transforms``) — confirm.
    for prefix in ("train", "val", "predict", "test"):
        setattr(self, f"_{prefix}_collate_in_worker_from_transform", None)

    stages = (
        ("train", train_transform, RunningStage.TRAINING),
        ("val", val_transform, RunningStage.VALIDATING),
        ("test", test_transform, RunningStage.TESTING),
        ("predict", predict_transform, RunningStage.PREDICTING),
    )

    # Keep the validated transforms as given, before conversion to modules.
    for prefix, transform, running_stage in stages:
        setattr(self, f"_{prefix}_transform", self._check_transforms(transform, running_stage))

    # The public attributes hold the module-converted versions.
    for prefix, _, _ in stages:
        setattr(self, f"{prefix}_transform", convert_to_modules(getattr(self, f"_{prefix}_transform")))

    self._callbacks: List[FlashCallback] = []
def __init__(
    self,
    pre_tensor_transform: Callable,
    to_tensor_transform: Callable,
    post_tensor_transform: Callable,
    assert_contains_tensor: bool = False,
):
    """Hold the three tensor-transform stages as modules.

    Args:
        pre_tensor_transform: stored module-converted.
        to_tensor_transform: stored module-converted.
        post_tensor_transform: stored module-converted.
        assert_contains_tensor: flag, stored as-is.
    """
    super().__init__()
    stage_fns = (
        ("pre_tensor_transform", pre_tensor_transform),
        ("to_tensor_transform", to_tensor_transform),
        ("post_tensor_transform", post_tensor_transform),
    )
    for attr_name, fn in stage_fns:
        setattr(self, attr_name, convert_to_modules(fn))
    self.assert_contains_tensor = assert_contains_tensor
def __init__(
    self,
    train_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
    val_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
    test_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
    predict_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
):
    """Store each stage's transform after passing it through ``convert_to_modules``."""
    super().__init__()
    for attr_name, transform in (
        ("train_transform", train_transform),
        ("val_transform", val_transform),
        ("test_transform", test_transform),
        ("predict_transform", predict_transform),
    ):
        setattr(self, attr_name, convert_to_modules(transform))
def __init__(
    self,
    uncollate_fn: Callable,
    per_batch_transform: Callable,
    per_sample_transform: Callable,
    save_fn: Optional[Callable] = None,
    save_per_sample: bool = False,
):
    """Hold the postprocessing callables as modules.

    Args:
        uncollate_fn: stored module-converted.
        per_batch_transform: stored module-converted.
        per_sample_transform: stored module-converted.
        save_fn: optional callable, stored module-converted.
        save_per_sample: boolean flag, stored as-is.
    """
    super().__init__()
    self.uncollate_fn = convert_to_modules(uncollate_fn)
    self.per_batch_transform = convert_to_modules(per_batch_transform)
    self.per_sample_transform = convert_to_modules(per_sample_transform)
    self.save_fn = convert_to_modules(save_fn)
    # A plain flag is not a transform, so it is not passed through the
    # transform-to-module conversion.
    self.save_per_sample = save_per_sample
def __init__(
    self,
    collate_fn: Callable,
    per_sample_transform: Union[Callable, _Sequential],
    per_batch_transform: Callable,
    stage: Optional[RunningStage] = None,
    apply_per_sample_transform: bool = True,
):
    """Set up the collation pipeline.

    Args:
        collate_fn: collation callable, stored module-converted.
        per_sample_transform: per-sample transform, stored module-converted.
        per_batch_transform: per-batch transform, stored module-converted.
        stage: optional running stage, stored as-is.
        apply_per_sample_transform: flag, stored as-is.
    """
    super().__init__()

    to_module = convert_to_modules
    self.collate_fn = to_module(collate_fn)
    self.per_sample_transform = to_module(per_sample_transform)
    self.per_batch_transform = to_module(per_batch_transform)

    self.apply_per_sample_transform = apply_per_sample_transform
    self.stage = stage
def __init__(
    self,
    train_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
    val_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
    test_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
    predict_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
):
    """Store each stage's transform (module-converted) and default ``_skip_mutual_check``."""
    super().__init__()

    attr_names = ("train_transform", "val_transform", "test_transform", "predict_transform")
    raw_values = (train_transform, val_transform, test_transform, predict_transform)
    for attr_name, raw in zip(attr_names, raw_values):
        setattr(self, attr_name, convert_to_modules(raw))

    # Only default the flag when nothing (instance or class) already provides it.
    if not hasattr(self, "_skip_mutual_check"):
        self._skip_mutual_check = False
def __init__(
    self,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    data_sources: Optional[Dict[str, DataSource]] = None,
    default_data_source: Optional[str] = None,
):
    """Resolve, validate and store the per-stage transforms and data sources.

    Any falsy transform argument is replaced by the matching
    ``default_<stage>_transforms`` attribute. Each resolved transform is
    passed through ``self._check_transforms`` and exposed as
    ``<stage>_transform``. Its module-converted form is kept on
    ``_<stage>_transform``.
    """
    super().__init__()

    # Fall back to the class-provided defaults for falsy (e.g. ``None``) inputs.
    if not train_transform:
        train_transform = self.default_train_transforms
    if not val_transform:
        val_transform = self.default_val_transforms
    if not test_transform:
        test_transform = self.default_test_transforms
    if not predict_transform:
        predict_transform = self.default_predict_transforms

    # Bookkeeping about the provided transforms; all start as ``None``.
    # NOTE(review): presumably filled in later (e.g. by ``_check_transforms``) — confirm.
    for prefix in ("train", "val", "predict", "test"):
        setattr(self, f"_{prefix}_collate_in_worker_from_transform", None)

    stages = (
        ("train", train_transform, RunningStage.TRAINING),
        ("val", val_transform, RunningStage.VALIDATING),
        ("test", test_transform, RunningStage.TESTING),
        ("predict", predict_transform, RunningStage.PREDICTING),
    )

    # The public attributes keep the validated transforms before conversion.
    for prefix, transform, running_stage in stages:
        setattr(self, f"{prefix}_transform", self._check_transforms(transform, running_stage))

    # The private attributes hold the module-converted versions.
    for prefix, _, _ in stages:
        setattr(self, f"_{prefix}_transform", convert_to_modules(getattr(self, f"{prefix}_transform")))

    self._data_sources = data_sources
    self._default_data_source = default_data_source
    self._callbacks: List[FlashCallback] = []
    self._default_collate: Callable = default_collate
def __init__(self, *args):
    """Initialise the parent with every positional argument module-converted, in order."""
    converted = tuple(convert_to_modules(item) for item in args)
    super().__init__(*converted)
def __init__(self, keys: Union[str, Sequence[str]], *args):
    """Initialise the parent with the module-converted ``args`` and store ``keys``.

    A single string key is wrapped in a one-element list. Any other sequence
    is stored unchanged.
    """
    super().__init__(*map(convert_to_modules, args))
    self.keys = [keys] if isinstance(keys, str) else keys