def extend(self, *extension: BaseExtension) -> 'BaseRunner':
    """
    Register one or more training extensions with this runner.

    Each extension is stored in ``self._extensions`` under the snake_case
    form of its class name, as a pair ``(ext, dependency_inject)`` where
    ``dependency_inject`` maps every hook name to the list of that hook's
    positional-or-keyword parameter names.

    Parameters
    ----------
    extension: BaseExtension
        Extension instances to register, in order. Registering an extension
        whose snake_case class name is already present replaces the earlier
        entry.

    Returns
    -------
    BaseRunner
        ``self``, so calls can be chained.
    """
    hook_names = (
        'before_proc', 'input_proc', 'step_forward', 'output_proc',
        'after_proc', 'on_reset', 'on_checkpoint',
    )

    def _keyword_params(func) -> list:
        # Only POSITIONAL_OR_KEYWORD parameters are collected; *args,
        # **kwargs, positional-only and keyword-only parameters are skipped.
        params = signature(func).parameters.values()
        return [param.name for param in params
                if param.kind == param.POSITIONAL_OR_KEYWORD]

    for ext in extension:
        key = camel_to_snake(ext.__class__.__name__)
        # Every hook must exist on the extension: getattr raises
        # AttributeError otherwise, leaving earlier extensions registered.
        injections = {
            hook: _keyword_params(getattr(ext, hook)) for hook in hook_names
        }
        self._extensions[key] = (ext, injections)
    return self
def loss_func(self, loss_func):
    """
    Store a loss function and record its type label.

    Parameters
    ----------
    loss_func:
        Loss object to use. When ``None``, the call is a no-op and any
        previously stored loss and label are left untouched.

    Notes
    -----
    The label is ``'train_'`` followed by the snake_case form of the loss
    object's class name, and it is stored in ``self._loss_type``.
    """
    if loss_func is None:
        return
    self._loss_func = loss_func
    class_name = loss_func.__class__.__name__
    self._loss_type = 'train_' + camel_to_snake(class_name)
def test_camel_to_snake_1():
    """A PascalCase identifier is converted to lowercase snake_case."""
    converted = camel_to_snake('CamelToSnake')
    assert converted == 'camel_to_snake'