def _get_stages_config(self, stages_config: Dict):
    """Merge stage-level defaults into every stage's config;
    ``state_params`` is accepted as a legacy alias for ``stage_params``."""
    stages_defaults = {}
    stages_config_out = OrderedDict()
    for key in self.STAGE_KEYWORDS:
        if key == "stage_params":
            # backward compatibility
            stages_defaults[key] = utils.merge_dicts(
                deepcopy(stages_config.get("state_params", {})),
                deepcopy(stages_config.get(key, {})),
            )
        else:
            stages_defaults[key] = deepcopy(stages_config.get(key, {}))

    for stage in stages_config:
        if (
            stage in self.STAGE_KEYWORDS
            or stage == "state_params"
            or stages_config.get(stage) is None
        ):
            continue
        stages_config_out[stage] = {}
        for key2 in self.STAGE_KEYWORDS:
            if key2 == "stage_params":
                # backward compatibility
                stages_config_out[stage][key2] = utils.merge_dicts(
                    deepcopy(stages_defaults.get("state_params", {})),
                    deepcopy(stages_defaults.get(key2, {})),
                    deepcopy(stages_config[stage].get("state_params", {})),
                    deepcopy(stages_config[stage].get(key2, {})),
                )
            else:
                stages_config_out[stage][key2] = utils.merge_dicts(
                    deepcopy(stages_defaults.get(key2, {})),
                    deepcopy(stages_config[stage].get(key2, {})),
                )

    return stages_config_out
def __init__(self, config: Dict):
    """
    Args:
        config (dict): dictionary with parameters
    """
    self._config: Dict = deepcopy(config)
    self._initial_seed: int = self._config.get("args", {}).get("seed", 42)
    self._verbose: bool = self._config.get("args", {}).get("verbose", False)
    self._check_time: bool = self._config.get("args", {}).get("timeit", False)
    self._check_run: bool = self._config.get("args", {}).get("check", False)
    self._overfit: bool = self._config.get("args", {}).get("overfit", False)
    self.__prepare_logdir()

    self._config["stages"]["stage_params"] = utils.merge_dicts(
        deepcopy(
            self._config["stages"].get("state_params", {})
        ),  # saved for backward compatibility
        deepcopy(self._config["stages"].get("stage_params", {})),
        deepcopy(self._config.get("args", {})),
        {"logdir": self._logdir},
    )
    self.stages_config: Dict = self._get_stages_config(
        self._config["stages"]
    )
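# A minimal sketch (not part of the class) of the config shape the
# ``__init__`` above consumes: the flags read from ``config["args"]`` and
# the stage-level params merged into ``stages/stage_params``.  The concrete
# values, ``num_epochs``, and the "train" stage name are hypothetical.
_example_config = {
    "args": {
        "seed": 42,        # -> self._initial_seed
        "verbose": True,   # -> self._verbose
        "timeit": False,   # -> self._check_time
        "check": False,    # -> self._check_run
        "overfit": False,  # -> self._overfit
    },
    "stages": {
        # legacy key, kept for backward compatibility
        "state_params": {"num_epochs": 1},
        # new-style key; assuming ``utils.merge_dicts`` lets later arguments
        # win, these values override ``state_params`` on conflicts
        "stage_params": {"num_epochs": 10},
        # every other key under "stages" is treated as a concrete stage
        "train": {},
    },
}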
def __init__(self, config: Dict):
    self._config = deepcopy(config)
    self._initial_seed = self._config.get("args", {}).get("seed", 42)
    self.__prepare_logdir()
    self._config["stages"]["state_params"] = utils.merge_dicts(
        deepcopy(self._config["stages"].get("state_params", {})),
        deepcopy(self._config.get("args", {})),
        {"logdir": self._logdir},
    )
    self.stages_config = self._get_stages_config(self._config["stages"])
def __init__(self, config: Dict):
    """
    Args:
        config (dict): dictionary of parameters
    """
    self._config = deepcopy(config)
    self._initial_seed = self._config.get("args", {}).get("seed", 42)
    self._verbose = self._config.get("args", {}).get("verbose", False)
    self._check_run = self._config.get("args", {}).get("check", False)
    self.__prepare_logdir()
    self._config["stages"]["state_params"] = utils.merge_dicts(
        deepcopy(self._config["stages"].get("state_params", {})),
        deepcopy(self._config.get("args", {})),
        {"logdir": self._logdir},
    )
    self.stages_config = self._get_stages_config(self._config["stages"])
def _get_stages_config(self, stages_config: Dict):
    stages_defaults = {}
    stages_config_out = OrderedDict()
    for key in self.STAGE_KEYWORDS:
        stages_defaults[key] = deepcopy(stages_config.get(key, {}))

    for stage in stages_config:
        if (
            stage in self.STAGE_KEYWORDS
            or stages_config.get(stage) is None
        ):
            continue
        stages_config_out[stage] = {}
        for key in self.STAGE_KEYWORDS:
            stages_config_out[stage][key] = utils.merge_dicts(
                deepcopy(stages_defaults.get(key, {})),
                deepcopy(stages_config[stage].get(key, {})),
            )

    return stages_config_out
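# Self-contained sketch of the merge semantics implemented by
# ``_get_stages_config`` above.  ``_merge_dicts`` is a stand-in for
# ``utils.merge_dicts`` under the assumption that it merges recursively
# with later arguments taking precedence; the keyword and stage names
# below are illustrative, not the actual ``STAGE_KEYWORDS`` list.
from copy import deepcopy


def _merge_dicts(*dicts: dict) -> dict:
    result: dict = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                result[key] = _merge_dicts(result[key], value)
            else:
                result[key] = deepcopy(value)
    return result


stages_config = {
    # top-level defaults, keyed by entries of ``self.STAGE_KEYWORDS``
    "optimizer_params": {"optimizer": "Adam", "lr": 1e-3},
    # every other key is a concrete stage; its values override the defaults
    "stage1": {"optimizer_params": {"lr": 1e-2}},
    "stage2": {},
}

merged = _merge_dicts(
    stages_config["optimizer_params"],
    stages_config["stage1"].get("optimizer_params", {}),
)
assert merged == {"optimizer": "Adam", "lr": 1e-2}  # defaults + override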
def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
    """Returns the loaders for a given stage."""
    data_params = dict(self.stages_config[stage]["data_params"])
    batch_size = data_params.pop("batch_size", 1)
    num_workers = data_params.pop("num_workers")
    drop_last = data_params.pop("drop_last", False)
    per_gpu_scaling = data_params.pop("per_gpu_scaling", False)
    distributed_rank = self.distributed_params.get("rank", -1)
    distributed = distributed_rank > -1

    datasets = self.get_datasets(stage=stage, **data_params)

    overridden_loaders_params = data_params.pop("loaders_params", {})
    assert isinstance(overridden_loaders_params, dict), \
        f"{overridden_loaders_params} should be Dict"

    loaders = OrderedDict()
    for name, ds_ in datasets.items():
        assert isinstance(ds_, (Dataset, dict)), \
            f"{ds_} should be Dataset or Dict"

        # per-loader overrides take precedence over the stage-wide defaults
        overridden_loader_params = overridden_loaders_params.pop(name, {})
        assert isinstance(overridden_loader_params, dict), \
            f"{overridden_loader_params} should be Dict"

        batch_size = overridden_loader_params.pop("batch_size", batch_size)
        num_workers = overridden_loader_params.pop(
            "num_workers", num_workers
        )

        # scale batch size and workers by the number of visible GPUs
        if per_gpu_scaling and not distributed:
            num_gpus = max(1, torch.cuda.device_count())
            batch_size *= num_gpus
            num_workers *= num_gpus

        loader_params = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "pin_memory": torch.cuda.is_available(),
            "drop_last": drop_last,
            **overridden_loader_params,
        }

        if isinstance(ds_, Dataset):
            loader_params["dataset"] = ds_
        elif isinstance(ds_, dict):
            assert "dataset" in ds_, \
                "You need to specify dataset for dataloader"
            loader_params = utils.merge_dicts(ds_, loader_params)
        else:
            raise NotImplementedError

        if distributed:
            # in distributed mode every loader needs a DistributedSampler
            sampler = loader_params.get("sampler")
            if sampler is not None:
                assert isinstance(sampler, DistributedSampler)
            else:
                loader_params["sampler"] = DistributedSampler(
                    dataset=loader_params["dataset"]
                )

        # shuffle only the train loaders, and only if no sampler is set
        loader_params["shuffle"] = (
            name.startswith("train")
            and loader_params.get("sampler") is None
        )

        if "batch_sampler" in loader_params:
            if distributed:
                raise ValueError(
                    "batch_sampler option is mutually "
                    "exclusive with distributed"
                )
            # batch_sampler is mutually exclusive with these DataLoader args
            for k in ("batch_size", "shuffle", "sampler", "drop_last"):
                loader_params.pop(k, None)

        if "worker_init_fn" not in loader_params:
            # reseed every worker deterministically from the experiment seed
            loader_params["worker_init_fn"] = \
                lambda x: utils.set_global_seed(self.initial_seed + x)

        loaders[name] = DataLoader(**loader_params)

    return loaders
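# A hypothetical ``data_params`` section illustrating the keys that
# ``get_loaders`` above pops and the per-loader overrides it honours;
# the loader name "valid" and all concrete values are assumptions.
_example_data_params = {
    "batch_size": 64,         # default batch size for every loader
    "num_workers": 4,         # required: popped above with no default
    "drop_last": False,
    "per_gpu_scaling": True,  # scale batch_size/num_workers by GPU count
                              # (only when not running distributed)
    "loaders_params": {
        # per-loader overrides, merged on top of the defaults above
        "valid": {"batch_size": 256},
    },
    # keys other than the ones above are forwarded to ``self.get_datasets``
}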