Example #1
    def __init__(
        self,
        hiddens,
        layer_fn=nn.Linear,
        bias=True,
        norm_fn=None,
        activation_fn=nn.ReLU,
        dropout=None,
        layer_order=None,
        residual=False
    ):
        super().__init__()
        # hack to prevent cycle imports
        from catalyst.contrib.registry import Registry

        layer_fn = Registry.name2nn(layer_fn)
        activation_fn = Registry.name2nn(activation_fn)
        norm_fn = Registry.name2nn(norm_fn)
        dropout = Registry.name2nn(dropout)

        layer_order = layer_order or ["layer", "norm", "drop", "act"]

        # a plain float is shorthand for nn.Dropout with that probability;
        # anything else is assumed to already be a dropout module factory
        if isinstance(dropout, float):
            dropout_fn = lambda: nn.Dropout(dropout)
        else:
            dropout_fn = dropout

        # the four builders below share a (f_in, f_out, bias) signature
        # so they can be dispatched uniformly through name2fn
        def _layer_fn(f_in, f_out, bias):
            return layer_fn(f_in, f_out, bias=bias)

        def _normalize_fn(f_in, f_out, bias):
            return norm_fn(f_out) if norm_fn is not None else None

        def _dropout_fn(f_in, f_out, bias):
            return dropout_fn() if dropout_fn is not None else None

        def _activation_fn(f_in, f_out, bias):
            return activation_fn() if activation_fn is not None else None

        name2fn = {
            "layer": _layer_fn,
            "norm": _normalize_fn,
            "drop": _dropout_fn,
            "act": _activation_fn,
        }

        net = []

        # one block per consecutive (f_in, f_out) pair of hiddens,
        # with sub-modules arranged according to layer_order
        for i, (f_in, f_out) in enumerate(pairwise(hiddens)):
            block = []
            for key in layer_order:
                fn = name2fn[key](f_in, f_out, bias)
                if fn is not None:
                    block.append((f"{key}", fn))
            block = torch.nn.Sequential(OrderedDict(block))
            if residual:
                block = ResidualWrapper(net=block)
            net.append((f"block_{i}", block))

        self.net = torch.nn.Sequential(OrderedDict(net))
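
The snippet above is only the __init__ body; the enclosing class (apparently a catalyst sequential-net module) and helpers such as pairwise and ResidualWrapper are defined elsewhere. For orientation, here is a minimal, self-contained sketch of the same idea, building one block per pair of consecutive feature sizes with a configurable layer_order, in plain PyTorch. The function name build_mlp, the hard-coded nn.BatchNorm1d / nn.ReLU choices, and the example sizes are illustrative assumptions, not catalyst API.

from collections import OrderedDict

import torch
import torch.nn as nn


def build_mlp(hiddens, layer_order=("layer", "norm", "drop", "act"), p_dropout=None):
    """Toy re-implementation of the block-building loop above (illustrative only)."""
    name2fn = {
        "layer": lambda f_in, f_out: nn.Linear(f_in, f_out),
        "norm": lambda f_in, f_out: nn.BatchNorm1d(f_out),
        "drop": lambda f_in, f_out: nn.Dropout(p_dropout) if p_dropout else None,
        "act": lambda f_in, f_out: nn.ReLU(),
    }
    blocks = []
    # consecutive (f_in, f_out) pairs: [16, 64, 10] -> (16, 64), (64, 10)
    for i, (f_in, f_out) in enumerate(zip(hiddens[:-1], hiddens[1:])):
        modules = [(key, name2fn[key](f_in, f_out)) for key in layer_order]
        modules = [(key, m) for key, m in modules if m is not None]
        blocks.append((f"block_{i}", nn.Sequential(OrderedDict(modules))))
    return nn.Sequential(OrderedDict(blocks))


net = build_mlp([16, 64, 10], p_dropout=0.1)
print(net(torch.randn(8, 16)).shape)  # torch.Size([8, 10])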
Example #2
    def __init__(self,
                 hiddens,
                 layer_fn=nn.Linear,
                 bias=True,
                 norm_fn=None,
                 activation_fn=nn.ReLU,
                 dropout=None,
                 layer_order=None,
                 residual=False):
        super().__init__()
        assert len(hiddens) > 1, "No sequence found"

        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)
        norm_fn = MODULES.get_if_str(norm_fn)
        dropout = MODULES.get_if_str(dropout)
        # weight init matched to the chosen activation; applied to the whole net at the end
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

        layer_order = layer_order or ["layer", "norm", "drop", "act"]

        if isinstance(dropout, float):
            dropout_fn = lambda: nn.Dropout(dropout)
        else:
            dropout_fn = dropout

        def _layer_fn(f_in, f_out, bias):
            return layer_fn(f_in, f_out, bias=bias)

        def _normalize_fn(f_in, f_out, bias):
            return norm_fn(f_out) if norm_fn is not None else None

        def _dropout_fn(f_in, f_out, bias):
            return dropout_fn() if dropout_fn is not None else None

        def _activation_fn(f_in, f_out, bias):
            return activation_fn() if activation_fn is not None else None

        name2fn = {
            "layer": _layer_fn,
            "norm": _normalize_fn,
            "drop": _dropout_fn,
            "act": _activation_fn,
        }

        net = []

        for i, (f_in, f_out) in enumerate(pairwise(hiddens)):
            block = []
            for key in layer_order:
                fn = name2fn[key](f_in, f_out, bias)
                if fn is not None:
                    block.append((f"{key}", fn))
            block = torch.nn.Sequential(OrderedDict(block))
            if residual:
                block = ResidualWrapper(net=block)
            net.append((f"block_{i}", block))

        self.net = torch.nn.Sequential(OrderedDict(net))
        self.net.apply(inner_init)
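
Both examples (and Example #3 below) iterate with a pairwise helper that is imported elsewhere and not shown, as is create_optimal_inner_init. The sketch below follows the standard itertools recipe and is assumed to match the behaviour the loop expects, namely consecutive overlapping pairs of hiddens:

from itertools import tee


def pairwise(iterable):
    """s -> (s0, s1), (s1, s2), (s2, s3), ... (standard itertools recipe)."""
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


print(list(pairwise([16, 64, 64, 10])))  # [(16, 64), (64, 64), (64, 10)]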
Example #3
    def __init__(
        self,
        hiddens,
        layer_fn: Union[str, Dict, List],
        norm_fn: Union[str, Dict, List] = None,
        dropout_fn: Union[str, Dict, List] = None,
        activation_fn: Union[str, Dict, List] = None,
        residual: Union[bool, str] = False,
        layer_order: List = None,
    ):

        super().__init__()
        assert len(hiddens) > 1, "No sequence found"

        # layer params
        layer_fn = _process_additional_params(layer_fn, hiddens[1:])
        # normalization params
        norm_fn = _process_additional_params(norm_fn, hiddens[1:])
        # dropout params
        dropout_fn = _process_additional_params(dropout_fn, hiddens[1:])
        # activation params
        activation_fn = _process_additional_params(activation_fn, hiddens[1:])

        # a bare residual=True means "hard" residual connections; keep it as a plain
        # string, since the per-block check below compares against "hard"/"soft"
        if isinstance(residual, bool) and residual:
            residual = "hard"

        layer_order = layer_order or ["layer", "norm", "drop", "act"]

        def _layer_fn(layer_fn, f_in, f_out, **kwargs):
            layer_fn = MODULES.get_if_str(layer_fn)
            layer_fn = layer_fn(f_in, f_out, **kwargs)
            return layer_fn

        def _normalization_fn(normalization_fn, f_in, f_out, **kwargs):
            normalization_fn = MODULES.get_if_str(normalization_fn)
            normalization_fn = \
                normalization_fn(f_out, **kwargs) \
                if normalization_fn is not None \
                else None
            return normalization_fn

        def _dropout_fn(dropout_fn, f_in, f_out, **kwargs):
            dropout_fn = MODULES.get_if_str(dropout_fn)
            dropout_fn = dropout_fn(**kwargs) \
                if dropout_fn is not None \
                else None
            return dropout_fn

        def _activation_fn(activation_fn, f_in, f_out, **kwargs):
            activation_fn = MODULES.get_if_str(activation_fn)
            activation_fn = activation_fn(**kwargs) \
                if activation_fn is not None \
                else None
            return activation_fn

        name2fn = {
            "layer": _layer_fn,
            "norm": _normalization_fn,
            "drop": _dropout_fn,
            "act": _activation_fn,
        }
        name2params = {
            "layer": layer_fn,
            "norm": norm_fn,
            "drop": dropout_fn,
            "act": activation_fn,
        }

        net = []
        for i, (f_in, f_out) in enumerate(pairwise(hiddens)):
            block = []
            for key in layer_order:
                sub_fn = name2fn[key]
                sub_params = deepcopy(name2params[key][i])

                if isinstance(sub_params, Dict):
                    sub_module = sub_params.pop("module")
                else:
                    sub_module = sub_params
                    sub_params = {}

                sub_block = sub_fn(sub_module, f_in, f_out, **sub_params)
                if sub_block is not None:
                    block.append((f"{key}", sub_block))

            block_ = OrderedDict(block)
            block = torch.nn.Sequential(block_)

            if block_.get("act", None) is not None:
                activation = block_["act"]
                activation_init = \
                    create_optimal_inner_init(nonlinearity=activation)
                block.apply(activation_init)

            if residual == "hard" or (residual == "soft" and f_in == f_out):
                block = ResidualWrapper(net=block)
            net.append((f"block_{i}", block))

        self.net = torch.nn.Sequential(OrderedDict(net))
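
_process_additional_params is referenced here but not defined in the snippet. Judging from how its result is indexed as name2params[key][i] and deep-copied before the "module" key is popped, it broadcasts a single spec (a string, a module class, or a dict with a "module" entry) into a per-layer list, or validates a list that is already per-layer. The following is a plausible reconstruction under that assumption, not the actual catalyst implementation:

def _process_additional_params(params, layers):
    """Broadcast a single layer spec to one entry per layer (assumed behaviour)."""
    if isinstance(params, list):
        # already one spec per layer: lengths must match
        assert len(params) == len(layers)
    else:
        # a single spec (str / class / dict) is repeated for every layer
        params = [params] * len(layers)
    return params


# e.g. layer_fn={"module": "Linear", "bias": False} becomes one dict per block;
# the loop above pops "module" and passes the remaining keys as **kwargs
print(_process_additional_params({"module": "Linear", "bias": False}, [64, 64, 10]))

Note that repeating a dict this way reuses the same object for every layer, which is presumably why the loop above takes a deepcopy of name2params[key][i] before popping "module".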