Example #1
0
    def get_loaders(self, stage: str):
        """
        Returns loaders for a given stage, attaching the collate function
        from ``self.get_collate_fn()`` to the ``train`` and ``valid`` loaders.

        Args:
            stage: stage name

        Returns:
            loaders built by ``get_loaders_from_params``

        Raises:
            KeyError: if the stage's ``loaders`` config has no ``train`` or
                ``valid`` entry.
        """
        loaders_params = dict(self._stage_config[stage]["loaders"])

        # ``dict(...)`` above is only a shallow copy: the per-loader dicts are
        # still the ones stored in ``self._stage_config``. Copy each of them
        # before adding ``collate_fn`` so the shared stage config is not
        # mutated with a collate function object on every call.
        for loader_name in ("train", "valid"):
            loader_params = dict(loaders_params[loader_name])
            loader_params["collate_fn"] = self.get_collate_fn()
            loaders_params[loader_name] = loader_params

        loaders = get_loaders_from_params(
            datasets=self.get_datasets(stage),
            initial_seed=self.seed,
            loaders_params=loaders_params,
        )
        return loaders
Example #2
0
    def get_loaders(self, stage: str) -> Dict[str, DataLoader]:
        """
        Build the data loaders configured for ``stage``.

        The stage's ``loaders`` config is converted to plain Python containers
        (with interpolations resolved); its ``datasets`` entry, if any, is
        dropped because datasets come from ``self.get_datasets`` instead.

        Args:
            stage: stage name

        Returns:
            Dict: loaders objects
        """
        stage_loaders_cfg = self._config.stages[stage].loaders
        params = OmegaConf.to_container(stage_loaders_cfg, resolve=True)
        # datasets are supplied separately below, not from the raw config
        params.pop("datasets", None)

        return get_loaders_from_params(
            datasets=self.get_datasets(stage=stage),
            initial_seed=self.seed,
            **params,
        )
Example #3
0
    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """
        Build the data loaders configured for ``stage``.

        Datasets are not built here; a callable bound to this stage
        (``self.get_datasets`` with ``stage`` fixed) is handed to
        ``get_loaders_from_params`` together with the stage's loader params.

        Args:
            stage: stage name

        Returns:
            Dict: loaders objects
        """
        params = dict(self._stage_config[stage]["loaders"])
        stage_datasets_fn = partial(self.get_datasets, stage=stage)
        return get_loaders_from_params(
            datasets_fn=stage_datasets_fn,
            initial_seed=self.seed,
            stage=stage,
            **params,
        )
Example #4
0
    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """
        Build the data loaders configured for ``stage``.

        ``self._get_loaders_from_params`` is tried first with the (deep-copied)
        stage loader params. If it returns ``None``, the loaders are built with
        ``get_loaders_from_params`` using datasets and samplers from
        ``self.get_datasets`` / ``self.get_samplers``.

        Args:
            stage: stage name

        Returns:
            Dict: loaders objects
        """
        params = deepcopy(self._stage_config[stage]["loaders"])

        custom_loaders = self._get_loaders_from_params(**params)
        if custom_loaders is not None:
            return custom_loaders

        # Fallback path: datasets and samplers are parsed by the dedicated
        # ``get_datasets`` / ``get_samplers`` methods, so their raw config
        # entries must not be forwarded as well.
        for parsed_key in ("datasets", "samplers"):
            params.pop(parsed_key, None)

        return get_loaders_from_params(
            datasets=self.get_datasets(stage=stage),
            samplers=self.get_samplers(stage=stage),
            initial_seed=self.seed,
            **params,
        )
Example #5
0
def _process_loaders(
    loaders: "OrderedDict[str, DataLoader]", initial_seed: int
) -> "OrderedDict[str, DataLoader]":
    if not isinstance(loaders[list(loaders.keys())[0]], (DataLoader, ILoaderWrapper)):
        loaders = get_loaders_from_params(initial_seed=initial_seed, **loaders)
    return loaders