Ejemplo n.º 1
0
    def extend(self, *extension: BaseExtension) -> 'BaseRunner':
        """
        Add training extensions to trainer.

        Each extension is stored in ``self._extensions`` under the snake_case
        form of its class name. The stored value is a pair
        ``(ext, dependency_inject)``. ``dependency_inject`` maps each hook
        method name to the names of that hook's positional-or-keyword
        parameters. Presumably this mapping is used elsewhere to inject
        arguments when a hook runs (TODO confirm against the caller).

        Parameters
        ----------
        *extension: BaseExtension
            Extensions to register. If ``self._extensions`` is a dict (as it
            appears to be), an extension whose snake_case class name is
            already registered replaces the earlier entry.

        Returns
        -------
        BaseRunner
            ``self``, so calls can be chained.

        Raises
        ------
        AttributeError
            If an extension does not provide one of the hook methods.

        """
        def _get_keyword_params(func) -> list:
            # Only POSITIONAL_OR_KEYWORD parameters are collected.
            # Positional-only, keyword-only, *args and **kwargs parameters
            # are skipped. When ``func`` is a bound method, ``signature``
            # already omits ``self``.
            sig = signature(func)
            return [
                p.name for p in sig.parameters.values()
                if p.kind == p.POSITIONAL_OR_KEYWORD
            ]

        # The hook names are identical for every extension, so build them once.
        hook_names = (
            'before_proc', 'input_proc', 'step_forward', 'output_proc',
            'after_proc', 'on_reset', 'on_checkpoint',
        )

        # merge exts to named_exts
        for ext in extension:
            name = camel_to_snake(ext.__class__.__name__)
            # hook name -> parameter names of that hook on this extension
            dependency_inject = dict(zip(
                hook_names,
                (_get_keyword_params(getattr(ext, m)) for m in hook_names),
            ))
            self._extensions[name] = (ext, dependency_inject)

        return self
Ejemplo n.º 2
0
 def loss_func(self, loss_func):
     """Store ``loss_func`` and record its type tag.

     The tag is ``'train_'`` followed by the snake_case form of the loss
     object's class name. It is written to ``self._loss_type``. Passing
     ``None`` leaves both ``self._loss_func`` and ``self._loss_type``
     untouched.
     """
     # None is ignored rather than clearing the current loss.
     if loss_func is None:
         return
     self._loss_func = loss_func
     snake_name = camel_to_snake(loss_func.__class__.__name__)
     self._loss_type = 'train_' + snake_name
Ejemplo n.º 3
0
def test_camel_to_snake_1():
    """A CamelCase name is converted to its snake_case form."""
    result = camel_to_snake('CamelToSnake')
    assert result == 'camel_to_snake'