def autogenerate_dataset(
    cls,
    data: Any,
    running_stage: RunningStage,
    whole_data_load_fn: Optional[Callable] = None,
    per_sample_load_fn: Optional[Callable] = None,
    data_pipeline: Optional[DataPipeline] = None,
) -> AutoDataset:
    """Build an ``AutoDataset`` for ``running_stage``.

    Any loader that is not supplied explicitly is looked up on
    ``cls.preprocess_cls``, using the name returned by
    ``DataPipeline._resolve_function_hierarchy`` for the given stage.

    Args:
        data: Forwarded unchanged as the first argument of ``AutoDataset``.
        running_stage: Stage used to resolve the hooks; also forwarded to
            ``AutoDataset``.
        whole_data_load_fn: Explicit ``load_data`` callable. When ``None``,
            the resolved ``load_data`` hook of ``cls.preprocess_cls`` is used.
        per_sample_load_fn: Explicit ``load_sample`` callable. When ``None``,
            the resolved ``load_sample`` hook of ``cls.preprocess_cls`` is used.
        data_pipeline: Forwarded unchanged to ``AutoDataset``.

    Returns:
        The constructed ``AutoDataset``.
    """

    def _stage_hook(hook_name: str) -> Callable:
        # Resolve the stage-specific attribute name first, then fetch that
        # attribute from the preprocess class itself.
        resolved_name = DataPipeline._resolve_function_hierarchy(
            hook_name, cls.preprocess_cls, running_stage, Preprocess
        )
        return getattr(cls.preprocess_cls, resolved_name)

    # ``cls.preprocess_cls`` is only touched when a loader is actually missing.
    load_data_fn = _stage_hook('load_data') if whole_data_load_fn is None else whole_data_load_fn
    load_sample_fn = _stage_hook('load_sample') if per_sample_load_fn is None else per_sample_load_fn

    return AutoDataset(data, load_data_fn, load_sample_fn, data_pipeline, running_stage=running_stage)
def generate_dataset(
    self,
    data: Optional[DATA_TYPE],
    running_stage: RunningStage,
) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
    """Generate a single dataset from ``data`` for the given ``running_stage``.

    ``data`` is passed through the ``load_data`` hook resolved on ``self`` for
    ``running_stage``; the result is wrapped in an ``AutoDataset`` when
    ``has_len`` reports a length and in an ``IterableAutoDataset`` otherwise.

    Args:
        data: The input handed to the resolved ``load_data`` hook. ``None``,
            or a non-empty sequence whose first element is ``None``, means
            "no data for this stage".
        running_stage: The stage used to resolve ``load_data``; also passed to
            the dataset constructor.

    Returns:
        The constructed dataset, or ``None`` when there is no data.
    """
    is_none = data is None

    # Only peek at the first element when there is one: an empty sequence is
    # "no items", not "no data", and must not raise ``IndexError`` here.
    if isinstance(data, Sequence) and len(data) > 0:
        is_none = data[0] is None

    if is_none:
        return None

    from flash.data.data_pipeline import DataPipeline

    # Stand-in dataset handed to ``load_data`` so the hook can record
    # metadata; that metadata is copied onto the real dataset below.
    mock_dataset = typing.cast(AutoDataset, MockDataset())

    with CurrentRunningStageFuncContext(running_stage, "load_data", self):
        load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
            self,
            DataPipeline._resolve_function_hierarchy(
                "load_data",
                self,
                running_stage,
                DataSource,
            ),
        )
        parameters = signature(load_data).parameters
        # Hooks that declare a ``dataset`` parameter also receive the mock dataset.
        if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
            data = load_data(data, mock_dataset)
        else:
            data = load_data(data)

    if has_len(data):
        dataset = AutoDataset(data, self, running_stage)
    else:
        dataset = IterableAutoDataset(data, self, running_stage)
    dataset.__dict__.update(mock_dataset.metadata)
    return dataset
def autogenerate_dataset(
    cls,
    data: Any,
    running_stage: RunningStage,
    whole_data_load_fn: Optional[Callable] = None,
    per_sample_load_fn: Optional[Callable] = None,
    data_pipeline: Optional[DataPipeline] = None,
    use_iterable_auto_dataset: bool = False,
) -> BaseAutoDataset:
    """Build a ``BaseAutoDataset`` or ``IterableAutoDataset`` for ``running_stage``.

    Any loader that is not supplied explicitly is looked up on the
    ``_preprocess_pipeline`` attribute of ``data_pipeline``, using the name
    returned by ``DataPipeline._resolve_function_hierarchy`` for the stage.

    Args:
        data: Forwarded unchanged as the first argument of the dataset class.
        running_stage: Stage used to resolve the hooks; also forwarded to the
            dataset class.
        whole_data_load_fn: Explicit ``load_data`` callable. When ``None``,
            the resolved ``load_data`` hook of the preprocess object is used.
        per_sample_load_fn: Explicit ``load_sample`` callable. When ``None``,
            the resolved ``load_sample`` hook of the preprocess object is used.
        data_pipeline: Source of the preprocess object; also forwarded to the
            dataset class.
        use_iterable_auto_dataset: When true, an ``IterableAutoDataset`` is
            built instead of a ``BaseAutoDataset``.

    Returns:
        The constructed dataset.
    """
    # ``None`` when no pipeline is given or it has no ``_preprocess_pipeline``.
    preprocess = getattr(data_pipeline, '_preprocess_pipeline', None)

    def _stage_hook(hook_name: str) -> Callable:
        # Resolve the stage-specific attribute name, then fetch it from ``preprocess``.
        resolved_name = DataPipeline._resolve_function_hierarchy(
            hook_name, preprocess, running_stage, Preprocess
        )
        return getattr(preprocess, resolved_name)

    load_data_fn = _stage_hook('load_data') if whole_data_load_fn is None else whole_data_load_fn
    load_sample_fn = _stage_hook('load_sample') if per_sample_load_fn is None else per_sample_load_fn

    dataset_cls = IterableAutoDataset if use_iterable_auto_dataset else BaseAutoDataset
    return dataset_cls(data, load_data_fn, load_sample_fn, data_pipeline, running_stage=running_stage)
def generate_dataset(
    self,
    data: Optional[DATA_TYPE],
    running_stage: RunningStage,
) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
    """Generate a single dataset with the given input to
    :meth:`~flash.data.data_source.DataSource.load_data` for the given ``running_stage``.

    Args:
        data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the dataset.
            ``None``, or a non-empty sequence whose first element is ``None``, means "no data for this stage".
        running_stage: The running_stage for this dataset.

    Returns:
        An ``AutoDataset`` when the loaded data reports a length (``has_len``), an ``IterableAutoDataset``
        otherwise, or ``None`` when there is no data.
    """
    is_none = data is None

    # Only peek at the first element when there is one: an empty sequence is
    # "no items", not "no data", and must not raise ``IndexError`` here.
    if isinstance(data, Sequence) and len(data) > 0:
        is_none = data[0] is None

    if is_none:
        return None

    from flash.data.data_pipeline import DataPipeline

    # Stand-in dataset handed to ``load_data`` so the hook can record
    # metadata; that metadata is copied onto the real dataset below.
    mock_dataset = typing.cast(AutoDataset, MockDataset())

    with CurrentRunningStageFuncContext(running_stage, "load_data", self):
        load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
            self,
            DataPipeline._resolve_function_hierarchy(
                "load_data",
                self,
                running_stage,
                DataSource,
            ),
        )
        parameters = signature(load_data).parameters
        # Hooks that declare a ``dataset`` parameter also receive the mock dataset.
        if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
            data = load_data(data, mock_dataset)
        else:
            data = load_data(data)

    if has_len(data):
        dataset = AutoDataset(data, self, running_stage)
    else:
        dataset = IterableAutoDataset(data, self, running_stage)
    dataset.__dict__.update(mock_dataset.metadata)
    return dataset
def running_stage(self, running_stage: RunningStage) -> None:
    """Store ``running_stage`` and rebind the stage-dependent ``load_sample`` members.

    After storing the stage on ``self._running_stage``, this rebuilds
    ``self._load_sample_context`` and points ``self.load_sample`` at the
    ``load_sample`` hook of ``self.data_source`` resolved for the new stage.

    Args:
        running_stage: The stage to switch to.
    """
    # TODO: something better than these function-local imports
    # (presumably here to avoid a circular import -- confirm).
    from flash.data.data_pipeline import DataPipeline  # noqa F811
    from flash.data.data_source import DataSource  # noqa F811

    # Must be stored first: everything below reads ``self.running_stage``.
    self._running_stage = running_stage

    self._load_sample_context = CurrentRunningStageFuncContext(
        self.running_stage,
        "load_sample",
        self.data_source,
    )

    # Name of the ``load_sample`` variant to use for the current stage.
    hook_name = DataPipeline._resolve_function_hierarchy(
        'load_sample',
        self.data_source,
        self.running_stage,
        DataSource,
    )
    self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self.data_source, hook_name)
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
    """Check per-stage hook resolution and worker-preprocessor wiring.

    A ``Preprocess`` subclass overrides a handful of stage-prefixed hooks.
    The test then asserts, for every running stage, (1) the name returned by
    ``DataPipeline._resolve_function_hierarchy`` for each entry of
    ``PREPROCESS_FUNCS`` and (2) the functions wrapped by the object returned
    from ``DataPipeline.worker_preprocessor``.
    """

    class CustomPreprocess(Preprocess):

        def load_data(self, *_, **__):
            pass

        def test_load_data(self, *_, **__):
            pass

        def predict_load_data(self, *_, **__):
            pass

        def predict_load_sample(self, *_, **__):
            pass

        def val_load_sample(self, *_, **__):
            pass

        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)

    # Every table below lists its entries in this stage order.
    stages = (
        RunningStage.TRAINING,
        RunningStage.VALIDATING,
        RunningStage.TESTING,
        RunningStage.PREDICTING,
    )

    # One ``{func_name: resolved_name}`` mapping per stage.
    resolved_per_stage = [
        {
            func_name: data_pipeline._resolve_function_hierarchy(
                func_name, data_pipeline._preprocess_pipeline, stage, Preprocess
            )
            for func_name in data_pipeline.PREPROCESS_FUNCS
        }
        for stage in stages
    ]

    # func_name -> expected resolved name for (train, val, test, predict).
    expected_names = {
        "load_data": (
            "load_data",
            "load_data",
            "test_load_data",
            "predict_load_data",
        ),
        "load_sample": (
            "load_sample",
            "val_load_sample",
            "load_sample",
            "predict_load_sample",
        ),
        "pre_tensor_transform": (
            "pre_tensor_transform",
            "val_pre_tensor_transform",
            "pre_tensor_transform",
            "pre_tensor_transform",
        ),
        "to_tensor_transform": (
            "to_tensor_transform",
            "to_tensor_transform",
            "to_tensor_transform",
            "predict_to_tensor_transform",
        ),
        "post_tensor_transform": (
            "train_post_tensor_transform",
            "post_tensor_transform",
            "post_tensor_transform",
            "post_tensor_transform",
        ),
        "collate": (
            "collate",
            "collate",
            "test_collate",
            "collate",
        ),
        "per_sample_transform_on_device": (
            "per_sample_transform_on_device",
            "val_per_sample_transform_on_device",
            "per_sample_transform_on_device",
            "per_sample_transform_on_device",
        ),
        "per_batch_transform_on_device": (
            "train_per_batch_transform_on_device",
            "per_batch_transform_on_device",
            "test_per_batch_transform_on_device",
            "per_batch_transform_on_device",
        ),
    }

    for func_name, names_by_stage in expected_names.items():
        for resolved, expected_name in zip(resolved_per_stage, names_by_stage):
            assert resolved[func_name] == expected_name

    # Build all four worker preprocessors before inspecting any of them.
    worker_preprocessors = [data_pipeline.worker_preprocessor(stage) for stage in stages]

    # Per stage: expected (pre_tensor, to_tensor, post_tensor, collate) functions.
    expected_funcs = [
        (
            preprocess.pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.train_post_tensor_transform,
            default_collate,
        ),
        (
            preprocess.val_pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.post_tensor_transform,
            # NOTE(review): validation collate is expected to be the identity,
            # presumably because ``val_per_sample_transform_on_device`` is
            # overridden above -- confirm against ``DataPipeline``.
            data_pipeline._identity,
        ),
        (
            preprocess.pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.post_tensor_transform,
            preprocess.test_collate,
        ),
        (
            preprocess.pre_tensor_transform,
            preprocess.predict_to_tensor_transform,
            preprocess.post_tensor_transform,
            default_collate,
        ),
    ]

    for worker, (pre_fn, to_fn, post_fn, collate_fn) in zip(worker_preprocessors, expected_funcs):
        sample_seq = worker.per_sample_transform
        assert sample_seq.pre_tensor_transform.func == pre_fn
        assert sample_seq.to_tensor_transform.func == to_fn
        assert sample_seq.post_tensor_transform.func == post_fn
        assert worker.collate_fn.func == collate_fn
        # ``per_batch_transform`` is never overridden, so it is the same for every stage.
        assert worker.per_batch_transform.func == preprocess.per_batch_transform