class DataLoaderCollectionConfig(cfg.BaseConfig):
    """Datasets for the train/test (and optional valid) splits plus their transforms.

    ``data_loader_collection`` hands the transforms produced by
    ``transform.transform_factory`` to each split's ``DatasetConfig`` and bundles
    the resulting loaders into a ``DataLoaderCollection``.
    """

    transform: TransformConfig = cfg.Option(required=True, type=TransformConfig)
    train: DatasetConfig = cfg.Option(required=True, type=DatasetConfig)
    test: DatasetConfig = cfg.Option(required=True, type=DatasetConfig)
    valid: Optional[DatasetConfig] = cfg.Option(nullable=True, type=DatasetConfig)

    @cached_property
    def data_loader_collection(self):
        """Build (once) the loader collection; ``None`` stands in for a missing valid split."""
        factory = self.transform.transform_factory
        self.train.set_transforms(factory.train_transform, factory.train_target_transform)
        self.test.set_transforms(factory.test_transform, factory.test_target_transform)

        valid_split = self.valid
        if valid_split is not None:
            valid_split.set_transforms(factory.valid_transform, factory.valid_target_transform)

        # Loaders are created in train -> test -> valid order, as before.
        train_loader = self.train.loader
        test_loader = self.test.loader
        valid_loader = None if valid_split is None else valid_split.loader
        return DataLoaderCollection(train_loader, test_loader, valid_loader)
class MetricsCollectionConfig(cfg.BaseConfig):
    """Names of metrics (looked up in the ``metrics`` module) to compute per split."""

    _train: List[str] = cfg.Option(name="train", default=[], type=cfg.config_list(str))
    _test: List[str] = cfg.Option(name="test", default=[], type=cfg.config_list(str))
    _valid: List[str] = cfg.Option(name="valid", default=[], type=cfg.config_list(str))

    @cached_property
    def metrics_collection(self):
        """Instantiate (once) the metrics of every split and bundle them."""
        return MetricsCollection(
            self._get_metric_list(self._train),
            self._get_metric_list(self._test),
            self._get_metric_list(self._valid),
        )

    def _get_metric_list(self, metric_names: List[str]) -> List[MetricBase]:
        """Instantiate each named metric with no arguments.

        Raises:
            KeyError: a name is not defined in the ``metrics`` module.
            TypeError: a name resolves to something whose call result is not a
                ``MetricBase`` (e.g. a helper function or an imported module).
        """
        metric_list = []
        for metric_name in metric_names:
            metric = self._get_metric_cls(metric_name)()
            # vars(metrics) exposes every module attribute, not only metric
            # classes; validate like ModelConfig/OptimizerConfig do.
            if not isinstance(metric, MetricBase):
                raise TypeError(
                    f"the metric {metric_name} must be a subclass of MetricBase "
                    f"or a Callable returning one, got {type(metric).__name__}."
                )
            metric_list.append(metric)
        return metric_list

    def _get_metric_cls(self, metric_name: str) -> Type[MetricBase]:
        """Look up ``metric_name`` in the ``metrics`` module; raise KeyError if absent."""
        cls = vars(metrics).get(metric_name, NOT_FOUND)
        if cls is NOT_FOUND:
            raise KeyError(f"the class {metric_name} was not found.")
        return cls
class MyConfig(cfg.BaseConfig):
    """Sample config using a renamed option, optional/required fields and a nested list."""

    # Loaded from the "a" key and exposed through the ``a`` property.
    _a: int = cfg.Option(name="a", required=True, description="hi", type=int)
    b: float = cfg.Option(required=False, type=float)
    c: bool = cfg.Option(required=True, type=bool)
    # The option type is a list of lists of MyConfigEmb; the annotation now matches it.
    li: List[List[MyConfigEmb]] = cfg.Option(
        type=cfg.config_list(cfg.config_list(MyConfigEmb)))

    def post_load(self):
        """Hook (presumably called by cfg after loading) that prints ``a``."""
        print(self.a)

    @property
    def a(self):
        """The configured "a" value plus one."""
        return self._a + 1
class TrainerConfig(cfg.BaseConfig):
    """Training-loop settings; ``post_load`` turns off features the device cannot use."""

    out: Path = cfg.Option(required=True, type=Path)
    train_rate: float = cfg.Option(default=0.1, type=float)
    max_epochs: int = cfg.Option(required=True, type=int)
    save_checkpoint_epochs: int = cfg.Option(default=20, type=int)
    device: torch.device = cfg.Option(default=DEVICE_AUTO, type=process_device)
    half: bool = cfg.Option(default=True, type=bool)
    ddp: bool = cfg.Option(default=True, type=bool)
    test_no_grad: bool = cfg.Option(default=True, type=bool)

    # TODO callback

    def post_load(self):
        """Disable ``half`` / ``ddp`` when the runtime or selected device cannot support them."""
        # _is_half_available already requires CUDA on the machine; also require
        # that the *selected* device is CUDA, matching the ddp condition, so a
        # CPU run on a GPU machine does not keep half precision enabled.
        self.half = self.half and self._is_half_available() and self._on_cuda_device()
        self.ddp = self.ddp and self._on_cuda_device()

    def _on_cuda_device(self):
        """Return True when CUDA is available and the configured device is a CUDA device."""
        return torch.cuda.is_available() and "cuda" in self.device.type

    def _is_half_available(self):
        """Return True when ``apex`` is importable and CUDA is available."""
        try:
            import apex  # noqa: F401  -- only importability is checked
        except ImportError:
            return False
        return torch.cuda.is_available()
class TransformConfig(ImporterConfig):
    """Importer config whose target must be a ``TransformsFactory`` subclass."""

    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)

    @cached_property
    def transform_factory(self) -> TransformsFactory:
        """Instantiate (once) the imported factory class with ``param``.

        Raises:
            TypeError: the imported object is not a ``TransformsFactory`` subclass.
        """
        imported = self.imported
        # issubclass() raises its own unhelpful TypeError for non-class objects
        # (e.g. a function), so check that we have a class first.
        if isinstance(imported, type) and issubclass(imported, TransformsFactory):
            return imported(**self.param)
        raise TypeError(f"imported must be a TransformsFactory subclass, got {imported!r}")
class LossFunctionConfig(ImporterConfig):
    """Importer config for a loss: either a class to instantiate or a plain callable."""

    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)

    @cached_property
    def loss_func(self):
        """Resolve (once) the configured loss into a callable.

        A class is instantiated with ``param``. For any other callable, a
        non-empty ``param`` is bound as keyword arguments instead of being
        silently dropped.
        """
        if isinstance(self.imported, type):
            return self.imported(**self.param)
        if self.param:
            from functools import partial
            return partial(self.imported, **self.param)
        return self.imported

    def eval_loss(self, data: Tensor, target: Tensor):
        """Apply the resolved loss to ``data`` and ``target``."""
        return self.loss_func(data, target)
class TrainingConfig(cfg.BaseConfig):
    """Top-level training profile combining model, optimizer, data, metrics, loss and trainer."""

    profile_name: str = cfg.Option(required=True, type=str)
    _model: ModelConfig = cfg.Option(name="model", required=True, type=ModelConfig)
    _optimizer: OptimizerConfig = cfg.Option(name="optimizer", required=True, type=OptimizerConfig)
    dataset: DataLoaderCollectionConfig = cfg.Option(required=True, type=DataLoaderCollectionConfig)
    metrics: MetricsCollectionConfig = cfg.Option(default={}, type=MetricsCollectionConfig)
    _trainer: TrainerConfig = cfg.Option(name="trainer", required=True, type=TrainerConfig)
    _loss_func: LossFunctionConfig = cfg.Option(name="loss_func", required=True, type=LossFunctionConfig)

    @property
    def model(self) -> Module:
        """Model instance built by the model config."""
        return self._model.model

    @cached_property
    def optimizer(self) -> Optimizer:
        """Attach the model to the optimizer config, then build the optimizer (once)."""
        optimizer_cfg = self._optimizer
        optimizer_cfg.set_model(self.model)
        return optimizer_cfg.optimizer

    @property
    def loss_func(self) -> Callable:
        """Bound ``eval_loss`` of the loss config."""
        return self._loss_func.eval_loss

    @property
    def data_loader_collection(self) -> DataLoaderCollection:
        """Loader collection built by the dataset config."""
        return self.dataset.data_loader_collection

    @property
    def metrics_collection(self) -> MetricsCollection:
        """Metrics collection built by the metrics config."""
        return self.metrics.metrics_collection

    @cached_property
    def trainer(self) -> Trainer:
        """Build (once) the Trainer from every sub-config."""
        # Resolve arguments in Trainer's positional order: some accessors have
        # side effects (set_transforms, set_model), so order is significant.
        out_prefix = self._get_out_prefix()
        loaders = self.data_loader_collection
        metric_collection = self.metrics_collection
        loss = self.loss_func
        net = self.model
        optim = self.optimizer
        settings = self._trainer
        return Trainer(
            out_prefix, loaders, metric_collection, loss, net, optim,
            settings.out, settings.train_rate, settings.max_epochs,
            settings.save_checkpoint_epochs, settings.device,
            settings.half, settings.ddp,
        )

    def _get_out_prefix(self) -> str:
        """Return ``<profile_name>_<imported model class name>``."""
        return "_".join((self.profile_name, self._model.imported.__name__))
class ModelConfig(ImporterConfig):
    """Importer config that builds a ``torch.nn.Module`` from ``param``."""

    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)

    @cached_property
    def model(self) -> Module:
        """Build the model once by calling the imported class/factory with ``param``.

        Raises:
            TypeError: ``param`` is not a dict, or the call does not yield a Module.
        """
        kwargs = self.param
        if isinstance(kwargs, dict):
            built = self.imported(**kwargs)
            if isinstance(built, Module):
                return built
            raise TypeError("imported_cls must be subclass of torch.nn.Module or a Callable returned it.")
        raise TypeError("params must be a dict")
class DatasetConfig(ImporterConfig):
    """Importer config for a ``Dataset`` class plus the DataLoader settings used to wrap it."""

    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)
    batch_size: int = cfg.Option(default=1, type=int)
    shuffle: bool = cfg.Option(default=False, type=bool)
    pin_memory: bool = cfg.Option(default=True, type=bool)
    num_workers: int = cfg.Option(default=0, type=int)
    # Not config options: injected via set_transforms(); may stay None.
    transform: Callable = None
    target_transform: Callable = None

    def set_transforms(self, transform: Callable, target_transform: Callable):
        """Store the transforms that ``dataset`` passes to the Dataset constructor."""
        self.transform = transform
        self.target_transform = target_transform

    @property
    def loader(self) -> DataLoader:
        """Return a new DataLoader over a freshly built dataset (rebuilt on every access)."""
        loader = DataLoader(
            self.dataset, shuffle=self.shuffle, pin_memory=self.pin_memory,
            batch_size=self.batch_size, num_workers=self.num_workers
        )
        return loader

    @property
    def dataset(self) -> Dataset:
        """Instantiate the imported Dataset class with the transforms and ``param``.

        Raises:
            ValueError: the import is not a Dataset subclass, or ``param`` is not
                a MutableMapping.
        """
        cls = self.imported
        if not isinstance(cls, type):
            raise ValueError(f"Dataset type should be a class, got {cls!r}.")
        if not issubclass(cls, Dataset):
            raise ValueError(f"Incorrect dataset type: {cls}.")
        if not isinstance(self.param, MutableMapping):
            raise ValueError('param must be instance of MutableMapping')
        dataset = cls(transform=self.transform, target_transform=self.target_transform, **self.param)
        return dataset
class OptimizerConfig(ImporterConfig):
    """Importer config that builds an optimizer over a model supplied via ``set_model``."""

    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)
    # Not a config option: must be injected with set_model() before reading ``optimizer``.
    model: Module = None

    def set_model(self, model: Module) -> None:
        """Attach the model whose parameters the optimizer will manage."""
        self.model = model

    @cached_property
    def optimizer(self) -> Optimizer:
        """Build the optimizer once from the model's parameters and ``param``.

        Raises:
            ValueError: no model has been set.
            TypeError: ``param`` is not a dict, or the call does not yield an Optimizer.
        """
        target_model = self.model
        if target_model is None:
            raise ValueError("model has not loaded.")
        kwargs = self.param
        if not isinstance(kwargs, dict):
            raise TypeError("params must be a dict")
        built = self.imported(target_model.parameters(), **kwargs)
        if isinstance(built, Optimizer):
            return built
        raise TypeError(
            "imported_cls must be subclass of torch.optim.Optimizer or a Callable returned it."
        )
class ImporterConfig(cfg.BaseConfig):
    """Base config that resolves a ``from X import Y`` statement to the named object."""

    _import_statement: str = cfg.Option(required=True, name="import_statement", type=str)
    # Group 1 captures the module path X (optionally dot-prefixed); group 4 the name Y.
    statement_re = r"from (\.*((\w+)\.)*\w+) import (\w+)"

    @cached_property
    def imported(self):
        """Import (once) and return the object named by ``import_statement``.

        Mirrors ``from X import Y``: Y is looked up as an attribute of module X
        (including PEP 562 module ``__getattr__``) and, failing that, imported
        as the submodule ``X.Y``.

        Raises:
            ValueError: the statement does not match ``statement_re``.
            KeyError: X has neither an attribute nor a submodule named Y.
        """
        _from, _import = self._split_from_and_import()
        # NOTE(review): the regex accepts dot-prefixed (relative) modules, but
        # import_module() is called without a ``package`` anchor, so those fail.
        module = importlib.import_module(_from)
        try:
            return getattr(module, _import)
        except AttributeError:
            pass
        # `from pkg import sub` works even if pkg.sub has not been imported yet;
        # a plain attribute/vars() lookup would miss it.
        submodule_name = f"{_from}.{_import}"
        try:
            return importlib.import_module(submodule_name)
        except ModuleNotFoundError as err:
            if err.name != submodule_name:
                raise  # a real missing dependency inside the submodule
            raise KeyError(_import) from None

    def _split_from_and_import(self):
        """Return ``(module_path, name)`` parsed from ``import_statement``.

        Raises:
            ValueError: the stripped statement does not fully match ``statement_re``.
        """
        import_statement = self._import_statement.strip()
        result = re.fullmatch(self.statement_re, import_statement)
        if result is None:
            raise ValueError(
                f"Incorrect import_statement: {self._import_statement} ({self.statement_re})"
            )
        g = result.groups()
        return g[0], g[3]
class PylonCameraConfig(cfg.BaseConfig):
    """Camera settings: exposure time, shutter mode, frame size and delay."""

    exposure_time: int = cfg.Option(type=int, default=100)
    shutter_mode: str = cfg.Option(type=str, default="GlobalResetRelease")
    width: int = cfg.Option(type=int, required=True)
    height: int = cfg.Option(type=int, required=True)
    delay: float = cfg.Option(type=float, required=True)
class MyConfigEmb(cfg.BaseConfig):
    """Small nested config used as the element type of MyConfig's ``li`` option."""

    x: int = cfg.Option(type=int)
    y: int = cfg.Option(type=int)
    i: bool = cfg.Option(type=bool)