Example #1
    def _get_stages_config(self, stages_config: Dict):
        """Builds the per-stage configuration: for every keyword in
        ``STAGE_KEYWORDS``, top-level defaults are merged with the
        per-stage overrides, later values taking priority.
        """
        stages_defaults = {}
        stages_config_out = OrderedDict()
        # collect top-level defaults for every stage keyword
        for key in self.STAGE_KEYWORDS:
            if key == "stage_params":
                # backward compatibility
                stages_defaults[key] = utils.merge_dicts(
                    deepcopy(stages_config.get("state_params", {})),
                    deepcopy(stages_config.get(key, {})),
                )
            else:
                stages_defaults[key] = deepcopy(stages_config.get(key, {}))
        # build each stage's config, skipping keyword sections
        # (including the legacy "state_params") and empty stages
        for stage in stages_config:
            if (stage in self.STAGE_KEYWORDS or stage == "state_params"
                    or stages_config.get(stage) is None):
                continue
            stages_config_out[stage] = {}
            for key2 in self.STAGE_KEYWORDS:
                if key2 == "stage_params":
                    # backward compatibility
                    stages_config_out[stage][key2] = utils.merge_dicts(
                        deepcopy(stages_defaults.get("state_params", {})),
                        deepcopy(stages_defaults.get(key2, {})),
                        deepcopy(stages_config[stage].get("state_params", {})),
                        deepcopy(stages_config[stage].get(key2, {})),
                    )
                else:
                    stages_config_out[stage][key2] = utils.merge_dicts(
                        deepcopy(stages_defaults.get(key2, {})),
                        deepcopy(stages_config[stage].get(key2, {})),
                    )

        return stages_config_out
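
For reference, a minimal sketch of the merge behavior, assuming STAGE_KEYWORDS includes "data_params" and that utils.merge_dicts performs a recursive dict merge in which later dictionaries take priority; the stage name and values below are hypothetical:

    stages_config = {
        "data_params": {"batch_size": 64, "num_workers": 4},
        "stage_train": {"data_params": {"batch_size": 32}},
    }
    # _get_stages_config(stages_config) would then yield, for the
    # "data_params" keyword of "stage_train":
    # {"batch_size": 32, "num_workers": 4}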
Example #2
    def __init__(self, config: Dict):
        """
        Args:
            config (dict): dictionary with parameters
        """
        self._config: Dict = deepcopy(config)
        args_params: Dict = self._config.get("args", {})
        self._initial_seed: int = args_params.get("seed", 42)
        self._verbose: bool = args_params.get("verbose", False)
        self._check_time: bool = args_params.get("timeit", False)
        self._check_run: bool = args_params.get("check", False)
        self._overfit: bool = args_params.get("overfit", False)

        self.__prepare_logdir()

        self._config["stages"]["stage_params"] = utils.merge_dicts(
            deepcopy(self._config["stages"].get(
                "state_params", {})),  # saved for backward compatibility
            deepcopy(self._config["stages"].get("stage_params", {})),
            deepcopy(self._config.get("args", {})),
            {"logdir": self._logdir},
        )
        self.stages_config: Dict = self._get_stages_config(
            self._config["stages"])
Example #3
    def __init__(self, config: Dict):
        self._config = deepcopy(config)
        self._initial_seed = self._config.get("args", {}).get("seed", 42)
        self.__prepare_logdir()

        self._config["stages"]["state_params"] = utils.merge_dicts(
            deepcopy(self._config["stages"].get("state_params", {})),
            deepcopy(self._config.get("args", {})), {"logdir": self._logdir}
        )
        self.stages_config = self._get_stages_config(self._config["stages"])
Example #4
    def __init__(self, config: Dict):
        """
        Args:
            config (dict): dictionary of parameters
        """
        self._config = deepcopy(config)
        self._initial_seed = self._config.get("args", {}).get("seed", 42)
        self._verbose = self._config.get("args", {}).get("verbose", False)
        self._check_run = self._config.get("args", {}).get("check", False)
        self.__prepare_logdir()

        self._config["stages"]["state_params"] = utils.merge_dicts(
            deepcopy(self._config["stages"].get("state_params", {})),
            deepcopy(self._config.get("args", {})), {"logdir": self._logdir})
        self.stages_config = self._get_stages_config(self._config["stages"])
Example #5
    def _get_stages_config(self, stages_config: Dict):
        stages_defaults = {}
        stages_config_out = OrderedDict()
        for key in self.STAGE_KEYWORDS:
            stages_defaults[key] = deepcopy(stages_config.get(key, {}))
        for stage in stages_config:
            if (stage in self.STAGE_KEYWORDS
                    or stages_config.get(stage) is None):
                continue
            stages_config_out[stage] = {}
            for key in self.STAGE_KEYWORDS:
                stages_config_out[stage][key] = utils.merge_dicts(
                    deepcopy(stages_defaults.get(key, {})),
                    deepcopy(stages_config[stage].get(key, {})),
                )

        return stages_config_out
Example #6
File: config.py Project: metya/catalyst
    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Returns the loaders for a given stage"""
        data_params = dict(self.stages_config[stage]["data_params"])

        batch_size = data_params.pop("batch_size", 1)
        num_workers = data_params.pop("num_workers")  # required key
        drop_last = data_params.pop("drop_last", False)
        per_gpu_scaling = data_params.pop("per_gpu_scaling", False)
        distributed_rank = self.distributed_params.get("rank", -1)
        distributed = distributed_rank > -1

        # pop the per-loader overrides before calling ``get_datasets``
        # so they are not forwarded as dataset kwargs
        overridden_loaders_params = data_params.pop("loaders_params", {})
        assert isinstance(overridden_loaders_params, dict), \
            f"{overridden_loaders_params} should be Dict"

        datasets = self.get_datasets(stage=stage, **data_params)

        loaders = OrderedDict()
        for name, ds_ in datasets.items():
            assert isinstance(ds_, (Dataset, dict)), \
                f"{ds_} should be Dataset or Dict"

            overridden_loader_params = overridden_loaders_params.pop(name, {})
            assert isinstance(overridden_loader_params, dict), \
                f"{overridden_loader_params} should be Dict"

            # use per-loader copies so that the per-GPU scaling below
            # does not compound across successive loaders
            loader_batch_size = overridden_loader_params.pop(
                "batch_size", batch_size
            )
            loader_num_workers = overridden_loader_params.pop(
                "num_workers", num_workers
            )

            if per_gpu_scaling and not distributed:
                num_gpus = max(1, torch.cuda.device_count())
                loader_batch_size *= num_gpus
                loader_num_workers *= num_gpus

            loader_params = {
                "batch_size": loader_batch_size,
                "num_workers": loader_num_workers,
                "pin_memory": torch.cuda.is_available(),
                "drop_last": drop_last,
                **overridden_loader_params
            }

            if isinstance(ds_, Dataset):
                loader_params["dataset"] = ds_
            elif isinstance(ds_, dict):
                assert "dataset" in ds_, \
                    "You need to specify dataset for dataloader"
                loader_params = utils.merge_dicts(ds_, loader_params)
            else:
                raise NotImplementedError

            if distributed:
                sampler = loader_params.get("sampler")
                if sampler is not None:
                    assert isinstance(sampler, DistributedSampler)
                else:
                    loader_params["sampler"] = DistributedSampler(
                        dataset=loader_params["dataset"])

            loader_params["shuffle"] = (name.startswith("train") and
                                        loader_params.get("sampler") is None)

            if "batch_sampler" in loader_params:
                if distributed:
                    raise ValueError("batch_sampler option is mutually "
                                     "exclusive with distributed")

                for k in ("batch_size", "shuffle", "sampler", "drop_last"):
                    loader_params.pop(k, None)

            if "worker_init_fn" not in loader_params:
                loader_params["worker_init_fn"] = \
                    lambda x: utils.set_global_seed(self.initial_seed + x)

            loaders[name] = DataLoader(**loader_params)

        return loaders
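
A sketch of the "data_params" section that get_loaders consumes; the loader names and values here are hypothetical, and any keys not listed are forwarded to get_datasets:

    data_params = {
        "batch_size": 64,       # default for every loader
        "num_workers": 4,       # required: popped without a default
        "per_gpu_scaling": True,
        "loaders_params": {     # per-loader overrides, keyed by name
            "valid": {"batch_size": 128},
        },
        # any other keys go to get_datasets(stage=..., **data_params)
    }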