class DataLoaderCollectionConfig(cfg.BaseConfig):
    transform: TransformConfig = cfg.Option(required=True, type=TransformConfig)

    train: DatasetConfig = cfg.Option(required=True, type=DatasetConfig)
    test: DatasetConfig = cfg.Option(required=True, type=DatasetConfig)
    valid: Optional[DatasetConfig] = cfg.Option(nullable=True, type=DatasetConfig)

    @cached_property
    def data_loader_collection(self):
        self.train.set_transforms(
            self.transform.transform_factory.train_transform,
            self.transform.transform_factory.train_target_transform
        )
        self.test.set_transforms(
            self.transform.transform_factory.test_transform,
            self.transform.transform_factory.test_target_transform
        )

        if self.valid is not None:
            self.valid.set_transforms(
                self.transform.transform_factory.valid_transform,
                self.transform.transform_factory.valid_target_transform
            )

        return DataLoaderCollection(
            self.train.loader,
            self.test.loader,
            self.valid.loader if self.valid is not None else None
        )
Example #2
class MetricsCollectionConfig(cfg.BaseConfig):
    _train: List[str] = cfg.Option(name="train",
                                   default=[],
                                   type=cfg.config_list(str))
    _test: List[str] = cfg.Option(name="test",
                                  default=[],
                                  type=cfg.config_list(str))
    _valid: List[str] = cfg.Option(name="valid",
                                   default=[],
                                   type=cfg.config_list(str))

    @cached_property
    def metrics_collection(self):
        return MetricsCollection(self._get_metric_list(self._train),
                                 self._get_metric_list(self._test),
                                 self._get_metric_list(self._valid))

    def _get_metric_list(self, metric_names: List[str]) -> List[MetricBase]:
        return [
            self._get_metric_cls(metric_name)() for metric_name in metric_names
        ]

    def _get_metric_cls(self, metric_name: str) -> Type[MetricBase]:
        cls = vars(metrics).get(metric_name, NOT_FOUND)

        if cls is NOT_FOUND:
            raise KeyError(f"the class {metric_name} was not found.")
        return cls
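
The `_get_metric_cls` lookup is an ordinary module-attribute search. A minimal sketch of the same pattern, using the standard library's datetime module in place of the project's metrics module:

import datetime

NOT_FOUND = object()   # sentinel, as in the original

def get_cls(module, name):
    cls = vars(module).get(name, NOT_FOUND)
    if cls is NOT_FOUND:
        raise KeyError(f"The class {name} was not found.")
    return cls

print(get_cls(datetime, "date"))   # <class 'datetime.date'>
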
Example #3
class MyConfig(cfg.BaseConfig):
    _a: int = cfg.Option(name="a", required=True, description="hi", type=int)
    b: float = cfg.Option(required=False, type=float)
    c: bool = cfg.Option(required=True, type=bool)
    li: List[List[MyConfigEmb]] = cfg.Option(
        type=cfg.config_list(cfg.config_list(MyConfigEmb)))

    def post_load(self):
        print(self.a)

    @property
    def a(self):
        return self._a + 1
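
A stand-alone sketch (no cfg machinery) of the private-option / public-property pattern used above: the raw value is stored under the private name and the public property derives from it, so post_load sees the derived value.

class MyConfigSketch:
    def __init__(self, a: int):
        self._a = a          # what the "a" option would bind
        self.post_load()

    def post_load(self):
        print(self.a)        # prints the derived value

    @property
    def a(self) -> int:
        return self._a + 1

MyConfigSketch(1)            # prints 2
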
Example #4
class TrainerConfig(cfg.BaseConfig):
    out: Path = cfg.Option(required=True, type=Path)
    train_rate: float = cfg.Option(default=0.1, type=float)

    max_epochs: int = cfg.Option(required=True, type=int)
    save_checkpoint_epochs: int = cfg.Option(default=20, type=int)

    device: torch.device = cfg.Option(default=DEVICE_AUTO, type=process_device)
    half: bool = cfg.Option(default=True, type=bool)
    ddp: bool = cfg.Option(default=True, type=bool)
    test_no_grad: bool = cfg.Option(default=True, type=bool)

    # TODO callback

    def post_load(self):
        self.half = self.half and self._is_half_available()
        self.ddp = (self.ddp and torch.cuda.is_available()
                    and "cuda" in self.device.type)

    def _is_half_available(self):
        # Mixed precision requires both NVIDIA apex and a CUDA device.
        try:
            import apex  # noqa: F401
            return torch.cuda.is_available()
        except ImportError:
            return False
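
A sketch of the flag resolution that post_load performs, pulled out as a free function; the device argument stands in for the already-processed `device` option.

import torch

def resolve_flags(half: bool, ddp: bool, device: torch.device):
    try:
        import apex  # noqa: F401
        half_available = torch.cuda.is_available()
    except ImportError:
        half_available = False
    half = half and half_available
    ddp = ddp and torch.cuda.is_available() and "cuda" in device.type
    return half, ddp

print(resolve_flags(True, True, torch.device("cpu")))  # ddp is always False on a CPU device
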
Example #5
class TransformConfig(ImporterConfig):
    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)

    @cached_property
    def transform_factory(self) -> TransformsFactory:
        if issubclass(self.imported, TransformsFactory):
            return self.imported(**self.param)
        raise TypeError("imported must be TransformFactory")
Example #6
class LossFunctionConfig(ImporterConfig):
    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)

    @cached_property
    def loss_func(self):
        if isinstance(self.imported, type):
            return self.imported(**self.param)
        return self.imported

    def eval_loss(self, data: Tensor, target: Tensor):
        return self.loss_func(data, target)
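
loss_func accepts either a class (instantiated with `param`) or a plain callable (returned unchanged). The same branching, shown with real torch objects:

import torch
import torch.nn as nn
import torch.nn.functional as F

def build_loss(imported, param):
    if isinstance(imported, type):
        return imported(**param)    # a class: instantiate it with the kwargs
    return imported                 # already a callable: use it directly

loss_a = build_loss(nn.CrossEntropyLoss, {"reduction": "sum"})
loss_b = build_loss(F.cross_entropy, {})

x, y = torch.randn(4, 3), torch.tensor([0, 1, 2, 1])
print(loss_a(x, y), loss_b(x, y))
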
Example #7
class TrainingConfig(cfg.BaseConfig):
    profile_name: str = cfg.Option(required=True, type=str)
    _model: ModelConfig = cfg.Option(name="model",
                                     required=True,
                                     type=ModelConfig)
    _optimizer: OptimizerConfig = cfg.Option(name="optimizer",
                                             required=True,
                                             type=OptimizerConfig)
    dataset: DataLoaderCollectionConfig = cfg.Option(
        required=True, type=DataLoaderCollectionConfig)
    metrics: MetricsCollectionConfig = cfg.Option(default={},
                                                  type=MetricsCollectionConfig)

    _trainer: TrainerConfig = cfg.Option(name="trainer",
                                         required=True,
                                         type=TrainerConfig)
    _loss_func: LossFunctionConfig = cfg.Option(name="loss_func",
                                                required=True,
                                                type=LossFunctionConfig)

    @property
    def model(self) -> Module:
        return self._model.model

    @cached_property
    def optimizer(self) -> Optimizer:
        self._optimizer.set_model(self.model)
        return self._optimizer.optimizer

    @property
    def loss_func(self) -> Callable:
        return self._loss_func.eval_loss

    @property
    def data_loader_collection(self) -> DataLoaderCollection:
        return self.dataset.data_loader_collection

    @property
    def metrics_collection(self) -> MetricsCollection:
        return self.metrics.metrics_collection

    @cached_property
    def trainer(self) -> Trainer:
        return Trainer(self._get_out_prefix(), self.data_loader_collection,
                       self.metrics_collection, self.loss_func, self.model,
                       self.optimizer, self._trainer.out,
                       self._trainer.train_rate, self._trainer.max_epochs,
                       self._trainer.save_checkpoint_epochs,
                       self._trainer.device, self._trainer.half,
                       self._trainer.ddp)

    def _get_out_prefix(self) -> str:
        cls_name = self._model.imported.__name__
        return self.profile_name + "_" + cls_name
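
Putting the pieces together, a hypothetical top-level profile for TrainingConfig (key names follow the Options above; nested sections follow the other examples on this page, and all module paths and metric names are invented):

training_profile = {
    "profile_name": "baseline",
    "model": {
        "import_statement": "from my_project.models import SimpleCNN",   # hypothetical
        "param": {"num_classes": 10},
    },
    "optimizer": {
        "import_statement": "from torch.optim import SGD",
        "param": {"lr": 0.01, "momentum": 0.9},
    },
    # "dataset": nested exactly as in the DataLoaderCollectionConfig example,
    "metrics": {"train": ["Accuracy"], "test": ["Accuracy"]},             # hypothetical names
    "trainer": {"out": "outputs", "max_epochs": 100},
    "loss_func": {"import_statement": "from torch.nn import CrossEntropyLoss"},
}
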
Example #8
class ModelConfig(ImporterConfig):
    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)

    @cached_property
    def model(self) -> Module:
        if not isinstance(self.param, dict):
            raise TypeError("param must be a dict")
        model = self.imported(**self.param)

        if not isinstance(model, Module):
            raise TypeError(
                "imported must be a subclass of torch.nn.Module "
                "or a callable returning one.")

        return model
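
Once the import resolves, the model property amounts to the following, shown here with a real torch class as a stand-in for the imported one:

import torch.nn as nn

imported, param = nn.Linear, {"in_features": 8, "out_features": 2}
if not isinstance(param, dict):
    raise TypeError("param must be a dict")
model = imported(**param)
assert isinstance(model, nn.Module)     # the check the property enforces
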
Example #9
class DatasetConfig(ImporterConfig):
    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)

    batch_size: int = cfg.Option(default=1, type=int)
    shuffle: bool = cfg.Option(default=False, type=bool)
    pin_memory: bool = cfg.Option(default=True, type=bool)
    num_workers: int = cfg.Option(default=0, type=int)
    transform: Optional[Callable] = None
    target_transform: Optional[Callable] = None

    def set_transforms(self, transform: Callable, target_transform: Callable):
        self.transform = transform
        self.target_transform = target_transform

    @property
    def loader(self) -> DataLoader:
        loader = DataLoader(
            self.dataset,
            shuffle=self.shuffle,
            pin_memory=self.pin_memory,
            batch_size=self.batch_size,
            num_workers=self.num_workers
        )

        return loader

    # TODO transform and target_transform
    @property
    def dataset(self) -> Dataset:
        cls = self.imported
        if not isinstance(cls, type):
            raise ValueError("The imported dataset must be a class.")
        if not issubclass(cls, Dataset):
            raise ValueError(f"Incorrect dataset type: {cls}.")
        if not isinstance(self.param, MutableMapping):
            raise ValueError("param must be an instance of MutableMapping")
        dataset = cls(transform=self.transform,
                      target_transform=self.target_transform,
                      **self.param)

        return dataset
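
The loader property boils down to a plain torch DataLoader call; here is the same call with a concrete TensorDataset and the Option defaults declared above:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(16, 3), torch.randint(0, 2, (16,)))
loader = DataLoader(dataset,
                    shuffle=False,      # Option defaults from DatasetConfig
                    pin_memory=True,
                    batch_size=1,
                    num_workers=0)
print(len(loader))                      # 16 batches of size 1
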
Example #10
class OptimizerConfig(ImporterConfig):
    param: dict = cfg.Option(default={}, type=cfg.process.flag_container)
    model: Optional[Module] = None

    def set_model(self, model: Module) -> None:
        self.model = model

    @cached_property
    def optimizer(self) -> Optimizer:
        if self.model is None:
            raise ValueError("model has not been set; call set_model() first.")

        if not isinstance(self.param, dict):
            raise TypeError("param must be a dict")

        optimizer = self.imported(self.model.parameters(), **self.param)
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "imported must be a subclass of torch.optim.Optimizer "
                "or a callable returning one.")
        return optimizer
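
Equivalent of what the optimizer property builds after set_model(), with real torch classes and an illustrative `param`:

import torch.nn as nn
from torch.optim import SGD, Optimizer

model = nn.Linear(8, 2)
optimizer = SGD(model.parameters(), **{"lr": 0.01})
assert isinstance(optimizer, Optimizer)   # the check the property enforces
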
Example #11
class ImporterConfig(cfg.BaseConfig):
    _import_statement: str = cfg.Option(required=True,
                                        name="import_statement",
                                        type=str)

    statement_re = r"from (\.*((\w+)\.)*\w+) import (\w+)"

    @cached_property
    def imported(self):
        _from, _import = self._split_from_and_import()
        return vars(importlib.import_module(_from))[_import]

    def _split_from_and_import(self):
        import_statement = self._import_statement.strip()
        result = re.fullmatch(self.statement_re, import_statement)
        if result is None:
            raise ValueError(
                f"Incorrect import_statement: {self._import_statement} ({self.statement_re})"
            )
        g = result.groups()
        return g[0], g[3]
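
The statement parsing in isolation: group 1 of the match is the module path handed to importlib, group 4 the attribute pulled out of it.

import importlib
import re

statement_re = r"from (\.*((\w+)\.)*\w+) import (\w+)"
m = re.fullmatch(statement_re, "from torch.nn import Linear")
module_path, name = m.group(1), m.group(4)            # "torch.nn", "Linear"
obj = vars(importlib.import_module(module_path))[name]
print(obj)                                            # <class 'torch.nn.modules.linear.Linear'>
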
Example #12
class PylonCameraConfig(cfg.BaseConfig):
    exposure_time: int = cfg.Option(default=100, type=int)
    shutter_mode: str = cfg.Option(default="GlobalResetRelease", type=str)
    width: int = cfg.Option(required=True, type=int)
    height: int = cfg.Option(required=True, type=int)
    delay: float = cfg.Option(required=True, type=float)
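
Hypothetical values for the camera options above; exposure_time and shutter_mode fall back to their defaults when omitted.

pylon_camera = {
    "width": 1920,      # required
    "height": 1080,     # required
    "delay": 0.5,       # required
    # "exposure_time": 100,                    # default
    # "shutter_mode": "GlobalResetRelease",    # default
}
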
Example #13
class MyConfigEmb(cfg.BaseConfig):
    x: int = cfg.Option(type=int)
    y: int = cfg.Option(type=int)
    i: bool = cfg.Option(type=bool)
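
Hypothetical data for the MyConfig example above: its `li` option is declared as a nested config_list, so each inner entry carries the x / y / i fields of MyConfigEmb.

my_config_data = {
    "a": 1,
    "c": True,
    "li": [[{"x": 0, "y": 0, "i": False},
            {"x": 1, "y": 2, "i": True}]],
}
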