Example #1
    def __init__(
        self,
        in_features: int,
        out_features: int,
        policy_type: str = None,
        out_activation: nn.Module = None
    ):
        super().__init__()
        assert policy_type in [
            "categorical", "gauss", "real_nvp", "logits", None
        ]

        # @TODO: refactor
        layer_fn = nn.Linear
        activation_fn = nn.ReLU
        squashing_fn = nn.Tanh
        bias = True

        if policy_type == "categorical":
            head_size = out_features
            policy_net = CategoricalPolicy()
        elif policy_type == "gauss":
            head_size = out_features * 2
            policy_net = GaussPolicy(squashing_fn)
        elif policy_type == "real_nvp":
            head_size = out_features * 2
            policy_net = RealNVPPolicy(
                action_size=out_features,
                layer_fn=layer_fn,
                activation_fn=activation_fn,
                squashing_fn=squashing_fn,
                bias=bias
            )
        else:
            head_size = out_features
            policy_net = None
            policy_type = "logits"

        self.policy_type = policy_type

        head_net = SequentialNet(
            hiddens=[in_features, head_size],
            layer_fn=nn.Linear,
            activation_fn=out_activation,
            norm_fn=None,
            bias=True
        )
        head_net.apply(outer_init)
        self.head_net = head_net

        self.policy_net = policy_net
        self._policy_fn = None
        if policy_net is None:
            self._policy_fn = lambda *args: args[0]
        elif isinstance(
            policy_net, (CategoricalPolicy, GaussPolicy, RealNVPPolicy)
        ):
            self._policy_fn = policy_net.forward
        else:
            raise NotImplementedError
Example #2
    def get_from_params(
        cls,
        observation_net_params=None,
        aggregation_net_params=None,
        main_net_params=None,
    ) -> "StateNet":
        assert observation_net_params is not None
        assert aggregation_net_params is None, "Lama is not implemented yet"

        observation_net = SequentialNet(**observation_net_params)
        main_net = SequentialNet(**main_net_params)
        net = cls(main_net=main_net, observation_net=observation_net)
        return net
Example #3
    def __init__(self,
                 action_size,
                 layer_fn,
                 activation_fn=nn.ReLU,
                 bias=True,
                 parity="odd"):
        """
        Conditional affine coupling layer used in Real NVP Bijector.
        Original paper: https://arxiv.org/abs/1605.08803
        Adaptation to RL: https://arxiv.org/abs/1804.02808
        Important notes
        ---------------
        1. State embeddings are expected to have size (action_size * 2).
        2. The scale and translation networks used in the Real NVP Bijector
        each have one hidden layer of (action_size) units with
        (activation_fn) activation.
        3. Parity ("odd" or "even") determines which part of the input
        is copied and which is transformed.
        """
        super().__init__()

        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)

        self.parity = parity
        if self.parity == "odd":
            self.copy_size = action_size // 2
        else:
            self.copy_size = action_size - action_size // 2

        self.scale_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)
        self.scale_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

        self.translation_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)
        self.translation_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
        self.scale_prenet.apply(inner_init)
        self.scale_net.apply(outer_init)
        self.translation_prenet.apply(inner_init)
        self.translation_net.apply(outer_init)
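To make the size bookkeeping in the notes above concrete, here is a minimal sketch that only recomputes the shapes the constructor derives (hypothetical action_size and parity):

action_size, parity = 5, "odd"
copy_size = action_size // 2 if parity == "odd" else action_size - action_size // 2
state_embedding_size = action_size * 2            # note 1: state embedding size
prenet_in = state_embedding_size + copy_size      # input size of the scale/translation prenets
prenet_hidden = action_size                       # note 2: one hidden layer of action_size units
net_out = action_size - copy_size                 # output size of the transformed part
assert copy_size + net_out == action_size         # copied + transformed parts cover the action
print(copy_size, prenet_in, prenet_hidden, net_out)  # 2, 12, 5, 3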
Example #4
    def get_from_params(
        cls,
        image_size: int = None,
        encoder_params: Dict = None,
        embedding_net_params: Dict = None,
        heads_params: Dict = None,
    ) -> "MultiHeadNet":

        encoder_params_ = deepcopy(encoder_params)
        embedding_net_params_ = deepcopy(embedding_net_params)
        heads_params_ = deepcopy(heads_params)

        model_name = encoder_params_.pop('model')

        encoder_net = registry.MODELS.get_instance(model_name,
                                                   **encoder_params_)

        enc_size = embedding_net_params_.pop('input_channels')
        embedding_net_params_["hiddens"].insert(0, enc_size)
        embedding_net = SequentialNet(**embedding_net_params_)
        emb_size = embedding_net_params_["hiddens"][-1]

        head_kwargs_ = {}
        for key, value in heads_params_.items():
            head_kwargs_[key] = nn.Linear(emb_size, value, bias=True)
        head_nets = nn.ModuleDict(head_kwargs_)

        net = cls(
            encoder_net=encoder_net,
            embedding_net=embedding_net,
            head_nets=head_nets,
        )

        return net
Example #5
    def get_from_params(
        cls,
        image_size: int = None,
        encoder_params: Dict = None,
        embedding_net_params: Dict = None,
        heads_params: Dict = None,
    ) -> "MultiHeadNet":

        encoder_params_ = deepcopy(encoder_params)
        embedding_net_params_ = deepcopy(embedding_net_params)
        heads_params_ = deepcopy(heads_params)

        encoder_net = ResnetEncoder(**encoder_params_)
        encoder_input_shape = (3, image_size, image_size)
        encoder_output = utils.get_network_output(encoder_net,
                                                  encoder_input_shape)
        enc_size = encoder_output.nelement()
        embedding_net_params_["hiddens"].insert(0, enc_size)
        embedding_net = SequentialNet(**embedding_net_params_)
        emb_size = embedding_net_params_["hiddens"][-1]

        head_kwargs_ = {}
        for key, value in heads_params_.items():
            head_kwargs_[key] = nn.Linear(emb_size, value, bias=True)
        head_nets = nn.ModuleDict(head_kwargs_)

        net = cls(
            encoder_net=encoder_net,
            embedding_net=embedding_net,
            head_nets=head_nets,
        )

        return net
Example #6
    def __init__(self,
                 enc,
                 n_cls,
                 hiddens,
                 emb_size,
                 activation_fn=torch.nn.ReLU,
                 norm_fn=None,
                 bias=True,
                 dropout=None):
        super().__init__()
        self.encoder = enc
        self.emb_net = SequentialNet(hiddens=hiddens + [emb_size],
                                     activation_fn=activation_fn,
                                     norm_fn=norm_fn,
                                     bias=bias,
                                     dropout=dropout)
        self.head = nn.Linear(emb_size, n_cls, bias=True)
Example #7
    def __init__(self, encoder_params, embedding_net_params, heads_params):
        super().__init__()

        encoder_params_ = deepcopy(encoder_params)
        embedding_net_params_ = deepcopy(embedding_net_params)
        heads_params_ = deepcopy(heads_params)

        self.encoder_net = encoder = ResnetEncoder(**encoder_params_)
        self.enc_size = encoder.out_features

        if self.enc_size is not None:
            embedding_net_params_["hiddens"].insert(0, self.enc_size)

        self.embedding_net = SequentialNet(**embedding_net_params_)
        self.emb_size = embedding_net_params_["hiddens"][-1]

        head_kwargs_ = {}
        for key, value in heads_params_.items():
            head_kwargs_[key] = nn.Linear(self.emb_size, value, bias=True)
        self.heads = nn.ModuleDict(head_kwargs_)
Example #8
    def create_from_params(cls,
                           state_shape,
                           observation_hiddens=None,
                           head_hiddens=None,
                           layer_fn=nn.Linear,
                           activation_fn=nn.ReLU,
                           dropout=None,
                           norm_fn=None,
                           bias=True,
                           layer_order=None,
                           residual=False,
                           out_activation=None,
                           history_aggregation_type=None,
                           lama_poolings=None,
                           **kwargs):
        assert len(kwargs) == 0
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry

        observation_hiddens = observation_hiddens or []
        head_hiddens = head_hiddens or []

        layer_fn = Registry.name2nn(layer_fn)
        activation_fn = Registry.name2nn(activation_fn)
        norm_fn = Registry.name2nn(norm_fn)
        out_activation = Registry.name2nn(out_activation)
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

        if isinstance(state_shape, int):
            state_shape = (state_shape, )

        if len(state_shape) in [1, 2]:
            # linear case: a single observation or a history of them,
            # i.e. state_shape like [history_len, obs_shape]
            # @TODO: handle lama/rnn correctly
            if not history_aggregation_type:
                state_size = reduce(lambda x, y: x * y, state_shape)
            else:
                state_size = reduce(lambda x, y: x * y, state_shape[1:])

            if len(observation_hiddens) > 0:
                observation_net = SequentialNet(hiddens=[state_size] +
                                                observation_hiddens,
                                                layer_fn=layer_fn,
                                                dropout=dropout,
                                                activation_fn=activation_fn,
                                                norm_fn=norm_fn,
                                                bias=bias,
                                                layer_order=layer_order,
                                                residual=residual)
                observation_net.apply(inner_init)
                obs_out = observation_hiddens[-1]
            else:
                observation_net = None
                obs_out = state_size

        elif len(state_shape) in [3, 4]:
            # cnn case: a single image or a history of them @TODO
            raise NotImplementedError
        else:
            raise NotImplementedError

        assert obs_out

        if history_aggregation_type == "lama_obs":
            aggregation_net = LamaPooling(features_in=obs_out,
                                          poolings=lama_poolings)
            aggregation_out = aggregation_net.features_out
        else:
            aggregation_net = None
            aggregation_out = obs_out

        main_net = SequentialNet(hiddens=[aggregation_out] + head_hiddens[:-1],
                                 layer_fn=layer_fn,
                                 dropout=dropout,
                                 activation_fn=activation_fn,
                                 norm_fn=norm_fn,
                                 bias=bias,
                                 layer_order=layer_order,
                                 residual=residual)
        main_net.apply(inner_init)

        # @TODO: place for memory network

        head_net = SequentialNet(hiddens=[head_hiddens[-2], head_hiddens[-1]],
                                 layer_fn=nn.Linear,
                                 activation_fn=out_activation,
                                 norm_fn=None,
                                 bias=True)
        head_net.apply(outer_init)

        critic_net = cls(observation_net=observation_net,
                         aggregation_net=aggregation_net,
                         main_net=main_net,
                         head_net=head_net,
                         policy_net=None)

        return critic_net
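A minimal sketch of how state_size is derived in the linear branch above (hypothetical shapes; with history aggregation only the per-observation part is flattened and the history axis is left for the aggregation net):

from functools import reduce

state_shape = (4, 8)  # hypothetical [history_len, obs_shape]
flat_size = reduce(lambda x, y: x * y, state_shape)         # 32: no aggregation, flatten everything
per_obs_size = reduce(lambda x, y: x * y, state_shape[1:])  # 8: lama/rnn aggregation over history
print(flat_size, per_obs_size)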
Example #9
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 policy_type: str = None,
                 out_activation: nn.Module = None):
        super().__init__()
        assert policy_type in [
            "categorical", "bernoulli", "diagonal-gauss", "squashing-gauss",
            "real-nvp", "logits", None
        ]

        # @TODO: refactor
        layer_fn = nn.Linear
        activation_fn = nn.ReLU
        squashing_fn = out_activation
        bias = True

        if policy_type == "categorical":
            assert out_activation is None
            head_size = out_features
            policy_net = CategoricalPolicy()
        elif policy_type == "bernoulli":
            assert out_activation is None
            head_size = out_features
            policy_net = BernoulliPolicy()
        elif policy_type == "diagonal-gauss":
            head_size = out_features * 2
            policy_net = DiagonalGaussPolicy()
        elif policy_type == "squashing-gauss":
            out_activation = None
            head_size = out_features * 2
            policy_net = SquashingGaussPolicy(squashing_fn)
        elif policy_type == "real-nvp":
            out_activation = None
            head_size = out_features * 2
            policy_net = RealNVPPolicy(action_size=out_features,
                                       layer_fn=layer_fn,
                                       activation_fn=activation_fn,
                                       squashing_fn=squashing_fn,
                                       bias=bias)
        else:
            head_size = out_features
            policy_net = None
            policy_type = "logits"

        self.policy_type = policy_type

        head_net = SequentialNet(
            hiddens=[in_features, head_size],
            layer_fn={
                "module": layer_fn,
                "bias": True
            },
            activation_fn=out_activation,
            norm_fn=None,
        )
        head_net.apply(outer_init)
        self.head_net = head_net

        self.policy_net = policy_net
        self._policy_fn = None
        if policy_net is not None:
            self._policy_fn = policy_net.forward
        else:
            self._policy_fn = lambda *args: args[0]
Example #10
    def create_from_params(
        cls,
        state_shape,
        action_size,
        observation_hiddens=None,
        head_hiddens=None,
        layer_fn=nn.Linear,
        activation_fn=nn.ReLU,
        dropout=None,
        norm_fn=None,
        bias=True,
        layer_order=None,
        residual=False,
        out_activation=None,
        observation_aggregation=None,
        lama_poolings=None,
        policy_type=None,
        squashing_fn=nn.Tanh,
        **kwargs
    ):
        assert len(kwargs) == 0

        observation_hiddens = observation_hiddens or []
        head_hiddens = head_hiddens or []

        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)
        norm_fn = MODULES.get_if_str(norm_fn)
        out_activation = MODULES.get_if_str(out_activation)
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

        if isinstance(state_shape, int):
            state_shape = (state_shape,)

        if len(state_shape) in [1, 2]:
            # linear case: a single observation or a history of them,
            # i.e. state_shape like [history_len, obs_shape]
            # @TODO: handle lama/rnn correctly
            if not observation_aggregation:
                observation_size = reduce(lambda x, y: x * y, state_shape)
            else:
                observation_size = reduce(lambda x, y: x * y, state_shape[1:])

            if len(observation_hiddens) > 0:
                observation_net = SequentialNet(
                    hiddens=[observation_size] + observation_hiddens,
                    layer_fn=layer_fn,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    norm_fn=norm_fn,
                    bias=bias,
                    layer_order=layer_order,
                    residual=residual
                )
                observation_net.apply(inner_init)
                obs_out = observation_hiddens[-1]
            else:
                observation_net = None
                obs_out = observation_size

        elif len(state_shape) in [3, 4]:
            # cnn case: a single image or a history of them @TODO
            raise NotImplementedError
        else:
            raise NotImplementedError

        assert obs_out

        if observation_aggregation == "lama_obs":
            aggregation_net = LamaPooling(
                features_in=obs_out,
                poolings=lama_poolings
            )
            aggregation_out = aggregation_net.features_out
        else:
            aggregation_net = None
            aggregation_out = obs_out

        main_net = SequentialNet(
            hiddens=[aggregation_out] + head_hiddens,
            layer_fn=layer_fn,
            dropout=dropout,
            activation_fn=activation_fn,
            norm_fn=norm_fn,
            bias=bias,
            layer_order=layer_order,
            residual=residual
        )
        main_net.apply(inner_init)

        # @TODO: place for memory network

        if policy_type == "gauss":
            head_size = action_size * 2
            policy_net = GaussPolicy(squashing_fn)
        elif policy_type == "real_nvp":
            head_size = action_size * 2
            policy_net = RealNVPPolicy(
                action_size=action_size,
                layer_fn=layer_fn,
                activation_fn=activation_fn,
                squashing_fn=squashing_fn,
                norm_fn=None,
                bias=bias
            )
        else:
            head_size = action_size
            policy_net = None

        head_net = SequentialNet(
            hiddens=[head_hiddens[-1], head_size],
            layer_fn=nn.Linear,
            activation_fn=out_activation,
            norm_fn=None,
            bias=True
        )
        head_net.apply(outer_init)

        actor_net = cls(
            observation_net=observation_net,
            aggregation_net=aggregation_net,
            main_net=main_net,
            head_net=head_net,
            policy_net=policy_net
        )

        return actor_net
Example #11
    def __init__(self,
                 encoder,
                 num_classes,
                 feature_net_hiddens=None,
                 emb_net_hiddens=None,
                 activation_fn=torch.nn.ReLU,
                 norm_fn=None,
                 bias=True,
                 dropout=None,
                 consensus=None,
                 kernel_size=1,
                 feature_net_skip_connection=False,
                 early_consensus=True):
        super().__init__()

        assert consensus is not None
        assert kernel_size in [1, 3, 5]

        consensus = consensus if isinstance(consensus, list) else [consensus]
        self.consensus = consensus

        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.feature_net_skip_connection = feature_net_skip_connection
        self.early_consensus = early_consensus

        nonlinearity = registry.MODULES.get_if_str(activation_fn)
        inner_init = create_optimal_inner_init(nonlinearity=nonlinearity)
        kernel2pad = {1: 0, 3: 1, 5: 2}

        def layer_fn(in_features, out_features, bias=True):

            return nn.Conv1d(in_features,
                             out_features,
                             bias=bias,
                             kernel_size=kernel_size,
                             padding=kernel2pad[kernel_size])

        if feature_net_hiddens is not None:
            self.feature_net = SequentialNet(
                hiddens=[encoder.out_features] + [feature_net_hiddens],
                layer_fn=layer_fn,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.feature_net.apply(inner_init)
            out_features = feature_net_hiddens
        else:
            # if there is no feature net, there is no need for a skip connection
            # (nothing to skip)
            assert not self.feature_net_skip_connection
            self.feature_net = lambda x: x
            out_features = encoder.out_features

        # The differences start here

        # Input channels to consensus function
        # (also to embedding net multiplied by len(consensus))
        if self.feature_net_skip_connection:
            in_channels = out_features + encoder.out_features
        else:
            in_channels = out_features

        consensus_fn = OrderedDict()
        for key in sorted(consensus):
            if key == "attention":
                self.attn = nn.Sequential(
                    nn.Conv1d(in_channels=in_channels,
                              out_channels=1,
                              kernel_size=kernel_size,
                              padding=kernel2pad[kernel_size],
                              bias=True), nn.Softmax(dim=1))

                def self_attn_fn(x):
                    x_a = x.transpose(1, 2)
                    x_attn = (self.attn(x_a) * x_a)
                    x_attn = x_attn.transpose(1, 2)
                    x_attn = x_attn.mean(1, keepdim=True)
                    return x_attn

                consensus_fn["attention"] = self_attn_fn
            elif key == "avg":
                consensus_fn[key] = lambda x: x.mean(1, keepdim=True)
            elif key == "max":
                consensus_fn[key] = lambda x: x.max(1, keepdim=True)[0]

        # Not the most optimized, but more understandable logic
        if self.early_consensus:
            out_features = emb_net_hiddens

            self.emb_net = SequentialNet(
                hiddens=[in_channels * len(consensus_fn), emb_net_hiddens],
                layer_fn=nn.Linear,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.emb_net.apply(inner_init)
        elif self.feature_net_skip_connection:
            out_features = out_features + self.encoder.out_features

        self.head = nn.Linear(out_features, num_classes, bias=True)

        if 'attention' in consensus:
            self.attn.apply(outer_init)
        self.head.apply(outer_init)

        self.consensus_fn = consensus_fn
Example #12
class CouplingLayer(nn.Module):
    def __init__(self,
                 action_size,
                 layer_fn,
                 activation_fn=nn.ReLU,
                 bias=True,
                 parity="odd"):
        """
        Conditional affine coupling layer used in Real NVP Bijector.
        Original paper: https://arxiv.org/abs/1605.08803
        Adaptation to RL: https://arxiv.org/abs/1804.02808
        Important notes
        ---------------
        1. State embeddings are expected to have size (action_size * 2).
        2. The scale and translation networks used in the Real NVP Bijector
        each have one hidden layer of (action_size) units with
        (activation_fn) activation.
        3. Parity ("odd" or "even") determines which part of the input
        is copied and which is transformed.
        """
        super().__init__()

        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)

        self.parity = parity
        if self.parity == "odd":
            self.copy_size = action_size // 2
        else:
            self.copy_size = action_size - action_size // 2

        self.scale_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)
        self.scale_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

        self.translation_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)
        self.translation_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
        self.scale_prenet.apply(inner_init)
        self.scale_net.apply(outer_init)
        self.translation_prenet.apply(inner_init)
        self.translation_net.apply(outer_init)

    def forward(self, action, state_embedding, log_pi):
        if self.parity == "odd":
            action_copy = action[:, :self.copy_size]
            action_transform = action[:, self.copy_size:]
        else:
            action_copy = action[:, -self.copy_size:]
            action_transform = action[:, :-self.copy_size]

        x = torch.cat((state_embedding, action_copy), dim=1)

        t = self.translation_prenet(x)
        t = self.translation_net(t)

        s = self.scale_prenet(x)
        s = self.scale_net(s)

        out_transform = t + action_transform * torch.exp(s)

        if self.parity == "odd":
            action = torch.cat((action_copy, out_transform), dim=1)
        else:
            action = torch.cat((out_transform, action_copy), dim=1)

        log_det_jacobian = s.sum(dim=1)
        log_pi = log_pi - log_det_jacobian

        return action, log_pi
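A minimal usage sketch of the layer above (hypothetical batch and action sizes; assumes CouplingLayer and its dependencies are importable; log_pi starts at zero and is reduced by the log-determinant of the transform):

layer = CouplingLayer(action_size=6, layer_fn=nn.Linear, parity="odd")
action = torch.randn(32, 6)
state_embedding = torch.randn(32, 12)   # action_size * 2, per the docstring
log_pi = torch.zeros(32)
action, log_pi = layer(action, state_embedding, log_pi)
print(action.shape, log_pi.shape)       # torch.Size([32, 6]) torch.Size([32])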
Example #13
    def get_from_params(
        cls,
        backbone_params: Dict = None,
        neck_params: Dict = None,
        heads_params: Dict = None,
    ) -> "GenericModel":

        backbone_params_ = deepcopy(backbone_params)
        neck_params_ = deepcopy(neck_params)
        heads_params_ = deepcopy(heads_params)

        if "requires_grad" in backbone_params_:
            requires_grad = backbone_params_.pop("requires_grad")
        else:
            requires_grad = False

        if "pretrained" in backbone_params_:
            pretrained = backbone_params_.pop("pretrained")
        else:
            pretrained = True

        if backbone_params_["model_name"] in pretrainedmodels.__dict__:
            model_name = backbone_params_.pop("model_name")

            backbone = pretrainedmodels.__dict__[model_name](
                num_classes=1000,
                pretrained="imagenet" if pretrained else None)

            enc_size = backbone.last_linear.in_features

        # elif backbone_params_["model_name"].startswith("efficientnet"):
        #     if pretrained is not None:
        #         backbone = EfficientNet.from_pretrained(**backbone_params_)
        #     else:
        #         backbone = EfficientNet.from_name(**backbone_params_)
        #
        #     backbone.set_swish(memory_efficient=True)
        #
        #     if in_channels != 3:
        #         Conv2d = get_same_padding_conv2d(
        #             image_size=backbone._global_params.image_size)
        #         out_channels = round_filters(32, backbone._global_params)
        #         backbone._conv_stem = Conv2d(in_channels, out_channels,
        #                                      kernel_size=3,
        #                                      stride=2, bias=False)
        #
        #     enc_size = backbone._conv_head.out_channels
        else:
            raise NotImplementedError("This model is not yet implemented")

        del backbone.last_linear
        # backbone._adapt_avg_pooling = nn.AdaptiveAvgPool2d(1)
        # backbone._dropout = nn.Dropout(p=0.2)

        neck = None
        if neck_params_:
            neck_params_["hiddens"].insert(0, enc_size)
            emb_size = neck_params_["hiddens"][-1]
            neck = SequentialNet(**neck_params_)
            # neck.requires_grad = requires_grad
        else:
            emb_size = enc_size

        if heads_params_ is not None:
            head_kwargs_ = {}
            for head, params in heads_params_.items():
                if isinstance(params, int):
                    head_kwargs_[head] = nn.Linear(emb_size, params, bias=True)
                elif isinstance(params, dict):
                    params["hiddens"].insert(0, emb_size)
                    head_kwargs_[head] = SequentialNet(**params)
                # head_kwargs_[head].requires_grad = requires_grad
            heads = nn.ModuleDict(head_kwargs_)
        else:
            heads = None

        model = cls(backbone=backbone, neck=neck, heads=heads)

        utils.set_requires_grad(model, requires_grad)

        print(model)

        return model
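For illustration, a minimal sketch of the params this factory expects (hypothetical values; the model_name must be present in pretrainedmodels.__dict__, and each head value may be an int for a plain linear head or a SequentialNet kwargs dict):

model = GenericModel.get_from_params(
    backbone_params={
        "model_name": "resnet18",   # looked up in pretrainedmodels.__dict__
        "pretrained": True,
        "requires_grad": False,
    },
    neck_params={"hiddens": [256]},
    heads_params={
        "logits": 10,                     # int -> nn.Linear(emb_size, 10)
        "embedding": {"hiddens": [128]},  # dict -> SequentialNet head
    },
)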
Example #14
def test_config2():
    config2 = {
        "in_features": 16,
        "heads_params": {
            "head1": {
                "hiddens": [2],
                "layer_fn": {
                    "module": "Linear",
                    "bias": True
                },
            },
            "_head2": {
                "_hidden": {
                    "hiddens": [16],
                    "layer_fn": {
                        "module": "Linear",
                        "bias": False
                    },
                },
                "head2_1": {
                    "hiddens": [32],
                    "layer_fn": {
                        "module": "Linear",
                        "bias": True
                    },
                    "normalize_output": True
                },
                "_head2_2": {
                    "_hidden": {
                        "hiddens": [16, 16, 16],
                        "layer_fn": {
                            "module": "Linear",
                            "bias": False
                        },
                    },
                    "head2_2_1": {
                        "hiddens": [32],
                        "layer_fn": {
                            "module": "Linear",
                            "bias": True
                        },
                        "normalize_output": False,
                    },
                },
            },
        },
    }

    hydra = Hydra.get_from_params(**config2)

    config2_ = copy.deepcopy(config2)
    _pop_normalization(config2_)
    heads_params = config2_["heads_params"]
    heads_params["head1"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_hidden"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["head2_1"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_head2_2"]["_hidden"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_head2_2"]["head2_2_1"]["hiddens"].insert(0, 16)

    net = nn.ModuleDict({
        "encoder":
        nn.Sequential(),
        "heads":
        nn.ModuleDict({
            "head1":
            nn.Sequential(
                OrderedDict([
                    ("net", SequentialNet(**heads_params["head1"])),
                ])),
            "_head2":
            nn.ModuleDict({
                "_hidden":
                nn.Sequential(
                    OrderedDict([
                        ("net",
                         SequentialNet(**heads_params["_head2"]["_hidden"]))
                    ])),
                "head2_1":
                nn.Sequential(
                    OrderedDict([
                        ("net",
                         SequentialNet(**heads_params["_head2"]["head2_1"])),
                        ("normalize", Normalize()),
                    ])),
                "_head2_2":
                nn.ModuleDict({
                    "_hidden":
                    nn.Sequential(
                        OrderedDict([("net",
                                      SequentialNet(**heads_params["_head2"]
                                                    ["_head2_2"]["_hidden"]))
                                     ])),
                    "head2_2_1":
                    nn.Sequential(
                        OrderedDict([("net",
                                      SequentialNet(**heads_params["_head2"]
                                                    ["_head2_2"]["head2_2_1"]))
                                     ])),
                })
            })
        })
    })

    _check_named_parameters(hydra.encoder, net["encoder"])
    _check_named_parameters(hydra.heads, net["heads"])
    assert hydra.embedders == {}

    input_ = torch.rand(1, 16)

    output_kv = hydra(input_)
    assert (input_ == output_kv["features"]).sum().item() == 16
    assert (input_ == output_kv["embeddings"]).sum().item() == 16
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    with pytest.raises(KeyError):
        output_kv = hydra(input_, target1=torch.ones(1, 2).long())
    with pytest.raises(KeyError):
        output_kv = hydra(input_, target2=torch.ones(1, 2).long())
    with pytest.raises(KeyError):
        output_kv = hydra(input_,
                          target1=torch.ones(1, 2).long(),
                          target2=torch.ones(1, 2).long())

    output_tuple = hydra.forward_tuple(input_)
    assert len(output_tuple) == 5
    assert (output_tuple[0] == output_kv["features"]).sum().item() == 16
    assert (output_tuple[1] == output_kv["embeddings"]).sum().item() == 16
Example #15
class TSN(nn.Module):
    def __init__(self,
                 encoder,
                 num_classes,
                 feature_net_hiddens=None,
                 emb_net_hiddens=None,
                 activation_fn=torch.nn.ReLU,
                 norm_fn=None,
                 bias=True,
                 dropout=None,
                 consensus=None,
                 kernel_size=1,
                 feature_net_skip_connection=False,
                 early_consensus=True):
        super().__init__()

        assert consensus is not None
        assert kernel_size in [1, 3, 5]

        consensus = consensus if isinstance(consensus, list) else [consensus]
        self.consensus = consensus

        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.feature_net_skip_connection = feature_net_skip_connection
        self.early_consensus = early_consensus

        nonlinearity = registry.MODULES.get_if_str(activation_fn)
        inner_init = create_optimal_inner_init(nonlinearity=nonlinearity)
        kernel2pad = {1: 0, 3: 1, 5: 2}

        def layer_fn(in_features, out_features, bias=True):

            return nn.Conv1d(in_features,
                             out_features,
                             bias=bias,
                             kernel_size=kernel_size,
                             padding=kernel2pad[kernel_size])

        if feature_net_hiddens is not None:
            self.feature_net = SequentialNet(
                hiddens=[encoder.out_features] + [feature_net_hiddens],
                layer_fn=layer_fn,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.feature_net.apply(inner_init)
            out_features = feature_net_hiddens
        else:
            # if there is no feature net, there is no need for a skip connection
            # (nothing to skip)
            assert not self.feature_net_skip_connection
            self.feature_net = lambda x: x
            out_features = encoder.out_features

        # The differences start here

        # Input channels to consensus function
        # (also to embedding net multiplied by len(consensus))
        if self.feature_net_skip_connection:
            in_channels = out_features + encoder.out_features
        else:
            in_channels = out_features

        consensus_fn = OrderedDict()
        for key in sorted(consensus):
            if key == "attention":
                self.attn = nn.Sequential(
                    nn.Conv1d(in_channels=in_channels,
                              out_channels=1,
                              kernel_size=kernel_size,
                              padding=kernel2pad[kernel_size],
                              bias=True), nn.Softmax(dim=1))

                def self_attn_fn(x):
                    x_a = x.transpose(1, 2)
                    x_attn = (self.attn(x_a) * x_a)
                    x_attn = x_attn.transpose(1, 2)
                    x_attn = x_attn.mean(1, keepdim=True)
                    return x_attn

                consensus_fn["attention"] = self_attn_fn
            elif key == "avg":
                consensus_fn[key] = lambda x: x.mean(1, keepdim=True)
            elif key == "max":
                consensus_fn[key] = lambda x: x.max(1, keepdim=True)[0]

        # Not the most optimized, but more understandable logic
        if self.early_consensus:
            out_features = emb_net_hiddens

            self.emb_net = SequentialNet(
                hiddens=[in_channels * len(consensus_fn), emb_net_hiddens],
                layer_fn=nn.Linear,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.emb_net.apply(inner_init)
        elif self.feature_net_skip_connection:
            out_features = out_features + self.encoder.out_features

        self.head = nn.Linear(out_features, num_classes, bias=True)

        if 'attention' in consensus:
            self.attn.apply(outer_init)
        self.head.apply(outer_init)

        self.consensus_fn = consensus_fn

    def forward(self, input):
        if len(input.shape) < 5:
            input = input.unsqueeze(1)
        bs, fl, ch, h, w = input.shape
        x = input.view(-1, ch, h, w)
        x = self.encoder(x)
        x = self.dropout(x)
        identity = x

        # in the simple case, feature_net is an identity mapping
        x = x.view(bs, fl, -1)
        x = x.transpose(1, 2)
        x = self.feature_net(x)
        x = x.transpose(1, 2).contiguous()  # because conv1d
        x = x.view(bs * fl, -1)
        if self.feature_net_skip_connection:
            x = torch.cat([identity, x], dim=-1)

        if self.early_consensus:
            x = x.view(bs, fl, -1)
            c_list = []

            for c_fn in self.consensus_fn.values():
                c_res = c_fn(x)
                c_list.append(c_res)
            x = torch.cat(c_list, dim=1)
            x = x.view(bs, -1)
            x = self.emb_net(x)

        x = self.head(x)

        if not self.early_consensus:
            x = x.view(bs, fl, -1)

            if self.consensus[0] == "avg":
                x = x.mean(1, keepdim=False)
            elif self.consensus[0] == "attention":
                identity = identity.view(bs, fl, -1)
                x_a = identity.transpose(1, 2)
                x_ = x.transpose(1, 2)
                x_attn = (self.attn(x_a) * x_)
                x_attn = x_attn.transpose(1, 2)
                x = x_attn.sum(1, keepdim=False)

        x = torch.sigmoid(x)  # with bce loss
        return x
Example #16
def test_config4():
    config_path = Path(__file__).absolute().parent / "config4.yml"
    config4 = utils.load_config(config_path)["model_params"]

    with pytest.raises(AssertionError):
        hydra = Hydra.get_from_params(**config4)
    config4["in_features"] = 16
    hydra = Hydra.get_from_params(**config4)

    config4_ = copy.deepcopy(config4)
    _pop_normalization(config4_)
    heads_params = config4_["heads_params"]
    heads_params["head1"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_hidden"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["head2_1"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_head2_2"]["_hidden"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_head2_2"]["head2_2_1"]["hiddens"].insert(0, 16)

    net = nn.ModuleDict({
        "encoder":
        nn.Sequential(),
        "heads":
        nn.ModuleDict({
            "head1":
            nn.Sequential(
                OrderedDict([
                    ("net", SequentialNet(**heads_params["head1"])),
                ])),
            "_head2":
            nn.ModuleDict({
                "_hidden":
                nn.Sequential(
                    OrderedDict([
                        ("net",
                         SequentialNet(**heads_params["_head2"]["_hidden"]))
                    ])),
                "head2_1":
                nn.Sequential(
                    OrderedDict([
                        ("net",
                         SequentialNet(**heads_params["_head2"]["head2_1"])),
                        ("normalize", Normalize()),
                    ])),
                "_head2_2":
                nn.ModuleDict({
                    "_hidden":
                    nn.Sequential(
                        OrderedDict([("net",
                                      SequentialNet(**heads_params["_head2"]
                                                    ["_head2_2"]["_hidden"]))
                                     ])),
                    "head2_2_1":
                    nn.Sequential(
                        OrderedDict([("net",
                                      SequentialNet(**heads_params["_head2"]
                                                    ["_head2_2"]["head2_2_1"]))
                                     ])),
                })
            })
        })
    })

    _check_named_parameters(hydra.encoder, net["encoder"])
    _check_named_parameters(hydra.heads, net["heads"])
    assert hydra.embedders == {}

    input_ = torch.rand(1, 16)

    output_kv = hydra(input_)
    assert (input_ == output_kv["features"]).sum().item() == 16
    assert (input_ == output_kv["embeddings"]).sum().item() == 16
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    with pytest.raises(KeyError):
        output_kv = hydra(input_, target1=torch.ones(1, 2).long())
    with pytest.raises(KeyError):
        output_kv = hydra(input_, target2=torch.ones(1, 2).long())
    with pytest.raises(KeyError):
        output_kv = hydra(input_,
                          target1=torch.ones(1, 2).long(),
                          target2=torch.ones(1, 2).long())

    output_tuple = hydra.forward_tuple(input_)
    assert len(output_tuple) == 5
    assert (output_tuple[0] == output_kv["features"]).sum().item() == 16
    assert (output_tuple[1] == output_kv["embeddings"]).sum().item() == 16
Example #17
def test_config3():
    config_path = Path(__file__).absolute().parent / "config3.yml"
    config3 = utils.load_config(config_path)["model_params"]

    hydra = Hydra.get_from_params(**config3)

    config3_ = copy.deepcopy(config3)
    _pop_normalization(config3_)
    encoder_params = config3_["encoder_params"]
    heads_params = config3_["heads_params"]
    heads_params["head1"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_hidden"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["head2_1"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_head2_2"]["_hidden"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_head2_2"]["head2_2_1"]["hiddens"].insert(0, 16)

    net = nn.ModuleDict({
        "encoder":
        SequentialNet(**encoder_params),
        "embedders":
        nn.ModuleDict({
            "target1":
            nn.Sequential(
                OrderedDict([
                    ("embedding",
                     nn.Embedding(embedding_dim=16, num_embeddings=2)),
                    ("normalize", Normalize()),
                ])),
            "target2":
            nn.Sequential(
                OrderedDict([
                    ("embedding",
                     nn.Embedding(embedding_dim=16, num_embeddings=2)),
                ])),
        }),
        "heads":
        nn.ModuleDict({
            "head1":
            nn.Sequential(
                OrderedDict([
                    ("net", SequentialNet(**heads_params["head1"])),
                ])),
            "_head2":
            nn.ModuleDict({
                "_hidden":
                nn.Sequential(
                    OrderedDict([
                        ("net",
                         SequentialNet(**heads_params["_head2"]["_hidden"]))
                    ])),
                "head2_1":
                nn.Sequential(
                    OrderedDict([
                        ("net",
                         SequentialNet(**heads_params["_head2"]["head2_1"])),
                        ("normalize", Normalize()),
                    ])),
                "_head2_2":
                nn.ModuleDict({
                    "_hidden":
                    nn.Sequential(
                        OrderedDict([("net",
                                      SequentialNet(**heads_params["_head2"]
                                                    ["_head2_2"]["_hidden"]))
                                     ])),
                    "head2_2_1":
                    nn.Sequential(
                        OrderedDict([("net",
                                      SequentialNet(**heads_params["_head2"]
                                                    ["_head2_2"]["head2_2_1"]))
                                     ])),
                })
            })
        })
    })

    _check_named_parameters(hydra.encoder, net["encoder"])
    _check_named_parameters(hydra.heads, net["heads"])
    _check_named_parameters(hydra.embedders, net["embedders"])

    input_ = torch.rand(1, 16)

    output_kv = hydra(input_)
    assert (input_ == output_kv["features"]).sum().item() == 16
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    output_kv = hydra(input_, target1=torch.ones(1, 2).long())
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
        "target1_embeddings",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    output_kv = hydra(input_, target2=torch.ones(1, 2).long())
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
        "target2_embeddings",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    output_kv = hydra(input_,
                      target1=torch.ones(1, 2).long(),
                      target2=torch.ones(1, 2).long())
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
        "target1_embeddings",
        "target2_embeddings",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    output_tuple = hydra.forward_tuple(input_)
    assert len(output_tuple) == 5
    assert (output_tuple[0] == output_kv["features"]).sum().item() == 16
    assert (output_tuple[1] == output_kv["embeddings"]).sum().item() == 16
Example #18
def test_config1():
    config1 = {
        "encoder_params": {
            "hiddens": [16, 16],
            "layer_fn": {
                "module": "Linear",
                "bias": False
            },
            "norm_fn": "LayerNorm",
        },
        "heads_params": {
            "head1": {
                "hiddens": [2],
                "layer_fn": {
                    "module": "Linear",
                    "bias": True
                },
            },
            "_head2": {
                "_hidden": {
                    "hiddens": [16],
                    "layer_fn": {
                        "module": "Linear",
                        "bias": False
                    },
                },
                "head2_1": {
                    "hiddens": [32],
                    "layer_fn": {
                        "module": "Linear",
                        "bias": True
                    },
                    "normalize_output": True
                },
                "_head2_2": {
                    "_hidden": {
                        "hiddens": [16, 16, 16],
                        "layer_fn": {
                            "module": "Linear",
                            "bias": False
                        },
                    },
                    "head2_2_1": {
                        "hiddens": [32],
                        "layer_fn": {
                            "module": "Linear",
                            "bias": True
                        },
                        "normalize_output": False,
                    },
                },
            },
        },
        "embedders_params": {
            "target1": {
                "num_embeddings": 2,
                "normalize_output": True,
            },
            "target2": {
                "num_embeddings": 2,
                "normalize_output": False,
            },
        }
    }

    hydra = Hydra.get_from_params(**config1)

    config1_ = copy.deepcopy(config1)
    _pop_normalization(config1_)
    encoder_params = config1_["encoder_params"]
    heads_params = config1_["heads_params"]
    heads_params["head1"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_hidden"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["head2_1"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_head2_2"]["_hidden"]["hiddens"].insert(0, 16)
    heads_params["_head2"]["_head2_2"]["head2_2_1"]["hiddens"].insert(0, 16)

    net = nn.ModuleDict({
        "encoder":
        SequentialNet(**encoder_params),
        "embedders":
        nn.ModuleDict({
            "target1":
            nn.Sequential(
                OrderedDict([
                    ("embedding",
                     nn.Embedding(embedding_dim=16, num_embeddings=2)),
                    ("normalize", Normalize()),
                ])),
            "target2":
            nn.Sequential(
                OrderedDict([
                    ("embedding",
                     nn.Embedding(embedding_dim=16, num_embeddings=2)),
                ])),
        }),
        "heads":
        nn.ModuleDict({
            "head1":
            nn.Sequential(
                OrderedDict([
                    ("net", SequentialNet(**heads_params["head1"])),
                ])),
            "_head2":
            nn.ModuleDict({
                "_hidden":
                nn.Sequential(
                    OrderedDict([
                        ("net",
                         SequentialNet(**heads_params["_head2"]["_hidden"]))
                    ])),
                "head2_1":
                nn.Sequential(
                    OrderedDict([
                        ("net",
                         SequentialNet(**heads_params["_head2"]["head2_1"])),
                        ("normalize", Normalize()),
                    ])),
                "_head2_2":
                nn.ModuleDict({
                    "_hidden":
                    nn.Sequential(
                        OrderedDict([("net",
                                      SequentialNet(**heads_params["_head2"]
                                                    ["_head2_2"]["_hidden"]))
                                     ])),
                    "head2_2_1":
                    nn.Sequential(
                        OrderedDict([("net",
                                      SequentialNet(**heads_params["_head2"]
                                                    ["_head2_2"]["head2_2_1"]))
                                     ])),
                })
            })
        })
    })

    _check_named_parameters(hydra.encoder, net["encoder"])
    _check_named_parameters(hydra.heads, net["heads"])
    _check_named_parameters(hydra.embedders, net["embedders"])

    input_ = torch.rand(1, 16)

    output_kv = hydra(input_)
    assert (input_ == output_kv["features"]).sum().item() == 16
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    output_kv = hydra(input_, target1=torch.ones(1, 2).long())
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
        "target1_embeddings",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    output_kv = hydra(input_, target2=torch.ones(1, 2).long())
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
        "target2_embeddings",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    output_kv = hydra(input_,
                      target1=torch.ones(1, 2).long(),
                      target2=torch.ones(1, 2).long())
    kv_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
        "target1_embeddings",
        "target2_embeddings",
    ]
    _check_lists(output_kv.keys(), kv_keys)

    output_tuple = hydra.forward_tuple(input_)
    assert len(output_tuple) == 5
    assert (output_tuple[0] == output_kv["features"]).sum().item() == 16
    assert (output_tuple[1] == output_kv["embeddings"]).sum().item() == 16