def get_loaders(
    self, stage: str, epoch: int = None,
) -> "OrderedDict[str, DataLoader]":
    """Return the data loaders for the requested stage.

    Args:
        stage: string with stage name
        epoch: epoch number (accepted for interface compatibility)

    Returns:
        Dict of loaders
    """
    data_params = dict(self.stages_config[stage]["data_params"])
    # Each split gets its own PadSequence instance to pad
    # variable-length samples at batch-collation time.
    loaders_params = {
        split: {"collate_fn": PadSequence()} for split in ("train", "valid")
    }
    return utils.get_loaders_from_params(
        get_datasets_fn=self.get_datasets,
        initial_seed=self.initial_seed,
        stage=stage,
        loaders_params=loaders_params,
        **data_params,
    )
def get_loaders(
    self, stage: str, epoch: int = None,
) -> "OrderedDict[str, DataLoader]":
    """Return the data loaders for the requested stage.

    Args:
        stage: string with stage name
        epoch: epoch number (accepted for interface compatibility)

    Returns:
        Dict of loaders
    """
    data_params = dict(self.stages_config[stage]["data_params"])
    # Build a masked-language-modeling collator from the configured
    # pretrained tokenizer.
    # NOTE(review): "model_name" is also forwarded below via **data_params —
    # presumably consumed downstream by get_loaders_from_params; verify.
    tokenizer = AutoTokenizer.from_pretrained(data_params["model_name"])
    mlm_collator = DataCollatorForLanguageModeling(tokenizer)
    # The same collator instance is shared by both splits.
    loaders_params = {
        split: {"collate_fn": mlm_collator} for split in ("train", "valid")
    }
    return utils.get_loaders_from_params(
        get_datasets_fn=self.get_datasets,
        initial_seed=self.initial_seed,
        stage=stage,
        loaders_params=loaders_params,
        **data_params,
    )
def get_loaders(
    self, stage: str, epoch: int = None,
) -> "OrderedDict[str, DataLoader]":
    """Return the loaders for a given stage.

    Args:
        stage: string with stage name
        epoch: epoch number (accepted for interface compatibility)

    Returns:
        Dict of loaders built from the stage's ``data_params`` config.
    """
    stage_data_params = dict(self.stages_config[stage]["data_params"])
    return utils.get_loaders_from_params(
        get_datasets_fn=self.get_datasets,
        initial_seed=self.initial_seed,
        stage=stage,
        **stage_data_params,
    )
def get_loaders(
    self, stage: str, epoch: int = None,
) -> "OrderedDict[str, DataLoader]":
    """Return the loaders for a given stage.

    Rebuilds ``self._loaders`` from ``self._datasets`` when datasets are
    present; during a training stage it also ensures a validation loader
    is available (falling back to the only loader when there is just one).

    Args:
        stage: string with stage name
        epoch: epoch number (accepted for interface compatibility)

    Returns:
        Dict of loaders
    """
    if self._datasets is not None:
        # Materialize loaders from the stored datasets specification.
        self._loaders = utils.get_loaders_from_params(
            initial_seed=self.initial_seed,
            **self._datasets,
        )
    if self._stage.startswith(STAGE_TRAIN_PREFIX):
        if len(self._loaders) == 1:
            # A single loader doubles as the validation loader.
            self._valid_loader = next(iter(self._loaders.keys()))
            warnings.warn(
                "Attention, there is only one dataloader - "
                + str(self._valid_loader)
            )
        assert self._valid_loader in self._loaders, (
            "The validation loader must be present "
            "in the loaders used during experiment."
        )
    return self._loaders
def process_loaders(
    loaders: "OrderedDict[str, DataLoader]",
    datasets: Dict,
    stage: str,
    valid_loader: str,
    initial_seed: int,
) -> "Tuple[OrderedDict[str, DataLoader], str]":
    """Prepare the loaders (and the validation-loader name) for a stage.

    When ``datasets`` is provided, loaders are rebuilt from it. For a
    non-inference (train) stage, a validation loader must exist among the
    loaders; if only one loader exists it is used as the validation loader.

    Args:
        loaders: current mapping of loader name to DataLoader
        datasets: datasets specification used to rebuild loaders, or None
        stage: stage name
        valid_loader: name of the loader used for validation
        initial_seed: seed passed through to loader construction

    Returns:
        Tuple of (loaders, validation loader name).
    """
    if datasets is not None:
        loaders = utils.get_loaders_from_params(
            initial_seed=initial_seed,
            **datasets,
        )
    is_infer_stage = stage.startswith(settings.stage_infer_prefix)
    if not is_infer_stage:  # train stage
        if len(loaders) == 1:
            # The only loader doubles as the validation loader.
            valid_loader = next(iter(loaders))
            warnings.warn(
                "Attention, there is only one dataloader - "
                + str(valid_loader)
            )
        assert valid_loader in loaders, (
            "The validation loader must be present "
            "in the loaders used during experiment.")
    return loaders, valid_loader