Example #1
    def __init__(
        self,
        in_features: int,
        out_features: int,
        policy_type: str = None,
        out_activation: nn.Module = None
    ):
        super().__init__()
        assert policy_type in [
            "categorical", "gauss", "real_nvp", "logits", None
        ]

        # @TODO: refactor
        layer_fn = nn.Linear
        activation_fn = nn.ReLU
        squashing_fn = nn.Tanh
        bias = True

        if policy_type == "categorical":
            head_size = out_features
            policy_net = CategoricalPolicy()
        elif policy_type == "gauss":
            head_size = out_features * 2
            policy_net = GaussPolicy(squashing_fn)
        elif policy_type == "real_nvp":
            head_size = out_features * 2
            policy_net = RealNVPPolicy(
                action_size=out_features,
                layer_fn=layer_fn,
                activation_fn=activation_fn,
                squashing_fn=squashing_fn,
                bias=bias
            )
        else:
            head_size = out_features
            policy_net = None
            policy_type = "logits"

        self.policy_type = policy_type

        head_net = SequentialNet(
            hiddens=[in_features, head_size],
            layer_fn=nn.Linear,
            activation_fn=out_activation,
            norm_fn=None,
            bias=True
        )
        head_net.apply(outer_init)
        self.head_net = head_net

        self.policy_net = policy_net
        self._policy_fn = None
        if policy_net is None:
            self._policy_fn = lambda *args: args[0]
        elif isinstance(
            policy_net, (CategoricalPolicy, GaussPolicy, RealNVPPolicy)
        ):
            self._policy_fn = policy_net.forward
        else:
            raise NotImplementedError
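The enclosing class is omitted from the snippet above. A hedged usage sketch, with "PolicyHead" standing in for whatever class owns this __init__ (the name is purely illustrative):

import torch

# hypothetical: PolicyHead is a stand-in name for the module whose __init__ is shown above
head = PolicyHead(in_features=256, out_features=4, policy_type="gauss")

state_embedding = torch.randn(8, 256)          # a batch of 8 state embeddings
head_output = head.head_net(state_embedding)   # shape (8, 4 * 2) for the "gauss" head
# head._policy_fn (GaussPolicy.forward in this case) then maps the head output to actions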
Example #2
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 policy_type: str = None,
                 out_activation: nn.Module = None):
        super().__init__()
        assert policy_type in [
            "categorical", "bernoulli", "diagonal-gauss", "squashing-gauss",
            "real-nvp", "logits", None
        ]

        # @TODO: refactor
        layer_fn = nn.Linear
        activation_fn = nn.ReLU
        squashing_fn = out_activation
        bias = True

        if policy_type == "categorical":
            assert out_activation is None
            head_size = out_features
            policy_net = CategoricalPolicy()
        elif policy_type == "bernoulli":
            assert out_activation is None
            head_size = out_features
            policy_net = BernoulliPolicy()
        elif policy_type == "diagonal-gauss":
            head_size = out_features * 2
            policy_net = DiagonalGaussPolicy()
        elif policy_type == "squashing-gauss":
            out_activation = None
            head_size = out_features * 2
            policy_net = SquashingGaussPolicy(squashing_fn)
        elif policy_type == "real-nvp":
            out_activation = None
            head_size = out_features * 2
            policy_net = RealNVPPolicy(action_size=out_features,
                                       layer_fn=layer_fn,
                                       activation_fn=activation_fn,
                                       squashing_fn=squashing_fn,
                                       bias=bias)
        else:
            head_size = out_features
            policy_net = None
            policy_type = "logits"

        self.policy_type = policy_type

        head_net = SequentialNet(
            hiddens=[in_features, head_size],
            layer_fn={
                "module": layer_fn,
                "bias": True
            },
            activation_fn=out_activation,
            norm_fn=None,
        )
        head_net.apply(outer_init)
        self.head_net = head_net

        self.policy_net = policy_net
        self._policy_fn = None
        if policy_net is not None:
            self._policy_fn = policy_net.forward
        else:
            self._policy_fn = lambda *args: args[0]
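For the Gaussian-style policies above, the head emits out_features * 2 values because the policy later splits them into a mean half and a log-std half. A minimal illustration of that split (the actual chunking lives inside the policy classes, which are not shown here):

import torch

out_features = 3
head_output = torch.randn(4, out_features * 2)   # hypothetical head output for a batch of 4

# a diagonal Gaussian policy typically splits the head output into two halves
mean, log_std = head_output.chunk(2, dim=-1)     # each of shape (4, 3)
std = log_std.exp()
action = mean + std * torch.randn_like(std)      # reparameterized sample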
Example #3
    def create_from_params(
        cls,
        state_shape,
        action_size,
        observation_hiddens=None,
        head_hiddens=None,
        layer_fn=nn.Linear,
        activation_fn=nn.ReLU,
        dropout=None,
        norm_fn=None,
        bias=True,
        layer_order=None,
        residual=False,
        out_activation=None,
        observation_aggregation=None,
        lama_poolings=None,
        policy_type=None,
        squashing_fn=nn.Tanh,
        **kwargs
    ):
        assert len(kwargs) == 0

        observation_hiddens = observation_hiddens or []
        head_hiddens = head_hiddens or []

        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)
        norm_fn = MODULES.get_if_str(norm_fn)
        out_activation = MODULES.get_if_str(out_activation)
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

        if isinstance(state_shape, int):
            state_shape = (state_shape,)

        if len(state_shape) in [1, 2]:
            # linear case: a single observation or a history of observations
            # state_shape like [history_len, obs_shape]
            # @TODO: handle lama/rnn correctly
            if not observation_aggregation:
                observation_size = reduce(lambda x, y: x * y, state_shape)
            else:
                observation_size = reduce(lambda x, y: x * y, state_shape[1:])

            if len(observation_hiddens) > 0:
                observation_net = SequentialNet(
                    hiddens=[observation_size] + observation_hiddens,
                    layer_fn=layer_fn,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    norm_fn=norm_fn,
                    bias=bias,
                    layer_order=layer_order,
                    residual=residual
                )
                observation_net.apply(inner_init)
                obs_out = observation_hiddens[-1]
            else:
                observation_net = None
                obs_out = observation_size

        elif len(state_shape) in [3, 4]:
            # cnn case: a single image or a stack of images @TODO
            raise NotImplementedError
        else:
            raise NotImplementedError

        assert obs_out

        if observation_aggregation == "lama_obs":
            aggregation_net = LamaPooling(
                features_in=obs_out,
                poolings=lama_poolings
            )
            aggregation_out = aggregation_net.features_out
        else:
            aggregation_net = None
            aggregation_out = obs_out

        main_net = SequentialNet(
            hiddens=[aggregation_out] + head_hiddens,
            layer_fn=layer_fn,
            dropout=dropout,
            activation_fn=activation_fn,
            norm_fn=norm_fn,
            bias=bias,
            layer_order=layer_order,
            residual=residual
        )
        main_net.apply(inner_init)

        # @TODO: place for memory network

        if policy_type == "gauss":
            head_size = action_size * 2
            policy_net = GaussPolicy(squashing_fn)
        elif policy_type == "real_nvp":
            head_size = action_size * 2
            policy_net = RealNVPPolicy(
                action_size=action_size,
                layer_fn=layer_fn,
                activation_fn=activation_fn,
                squashing_fn=squashing_fn,
                norm_fn=None,
                bias=bias
            )
        else:
            head_size = action_size
            policy_net = None

        # assumes head_hiddens is non-empty: its last entry feeds the final linear head
        head_net = SequentialNet(
            hiddens=[head_hiddens[-1], head_size],
            layer_fn=nn.Linear,
            activation_fn=out_activation,
            norm_fn=None,
            bias=True
        )
        head_net.apply(outer_init)

        actor_net = cls(
            observation_net=observation_net,
            aggregation_net=aggregation_net,
            main_net=main_net,
            head_net=head_net,
            policy_net=policy_net
        )

        return actor_net
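Assuming create_from_params is a classmethod on an actor network class (called "Actor" below only for illustration), a typical config-driven call might look like:

import torch.nn as nn

# hypothetical call; "Actor" stands in for the class that owns create_from_params
actor = Actor.create_from_params(
    state_shape=(24,),            # flat observation with 24 features
    action_size=4,
    head_hiddens=[256, 256],      # the last entry feeds the final linear head
    activation_fn=nn.ReLU,
    policy_type="gauss",          # head emits action_size * 2 values for GaussPolicy
)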
Example #4
class CouplingLayer(nn.Module):
    def __init__(self,
                 action_size,
                 layer_fn,
                 activation_fn=nn.ReLU,
                 bias=True,
                 parity="odd"):
        """
        Conditional affine coupling layer used in Real NVP Bijector.
        Original paper: https://arxiv.org/abs/1605.08803
        Adaptation to RL: https://arxiv.org/abs/1804.02808
        Important notes
        ---------------
        1. State embeddings are supposed to have size (action_size * 2).
        2. The scale and translation networks used in the Real NVP Bijector
        both have one hidden layer of `action_size` units activated by
        `activation_fn`.
        3. Parity ("odd" or "even") determines which part of the input
        is being copied and which is being transformed.
        """
        super().__init__()

        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)

        self.parity = parity
        if self.parity == "odd":
            self.copy_size = action_size // 2
        else:
            self.copy_size = action_size - action_size // 2

        self.scale_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)
        self.scale_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

        self.translation_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)
        self.translation_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
        self.scale_prenet.apply(inner_init)
        self.scale_net.apply(outer_init)
        self.translation_prenet.apply(inner_init)
        self.translation_net.apply(outer_init)

    def forward(self, action, state_embedding, log_pi):
        if self.parity == "odd":
            action_copy = action[:, :self.copy_size]
            action_transform = action[:, self.copy_size:]
        else:
            action_copy = action[:, -self.copy_size:]
            action_transform = action[:, :-self.copy_size]

        x = torch.cat((state_embedding, action_copy), dim=1)

        t = self.translation_prenet(x)
        t = self.translation_net(t)

        s = self.scale_prenet(x)
        s = self.scale_net(s)

        out_transform = t + action_transform * torch.exp(s)

        if self.parity == "odd":
            action = torch.cat((action_copy, out_transform), dim=1)
        else:
            action = torch.cat((out_transform, action_copy), dim=1)

        log_det_jacobian = s.sum(dim=1)
        log_pi = log_pi - log_det_jacobian

        return action, log_pi
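A minimal usage sketch for the coupling layer, assuming SequentialNet and the initializers it relies on are importable; note that, per the docstring, the state embedding must have size action_size * 2:

import torch
import torch.nn as nn

action_size = 6
layer = CouplingLayer(action_size=action_size, layer_fn=nn.Linear, parity="odd")

batch_size = 4
action = torch.randn(batch_size, action_size)
state_embedding = torch.randn(batch_size, action_size * 2)
log_pi = torch.zeros(batch_size)

new_action, new_log_pi = layer(action, state_embedding, log_pi)
# new_log_pi = log_pi - sum(s): the log-density is corrected by the
# log-determinant of the affine transform's Jacobian (change of variables)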
Example #5
    def create_from_params(cls,
                           state_shape,
                           observation_hiddens=None,
                           head_hiddens=None,
                           layer_fn=nn.Linear,
                           activation_fn=nn.ReLU,
                           dropout=None,
                           norm_fn=None,
                           bias=True,
                           layer_order=None,
                           residual=False,
                           out_activation=None,
                           history_aggregation_type=None,
                           lama_poolings=None,
                           **kwargs):
        assert len(kwargs) == 0
        # hack to prevent cycle imports
        from catalyst.contrib.registry import Registry

        observation_hiddens = observation_hiddens or []
        head_hiddens = head_hiddens or []

        layer_fn = Registry.name2nn(layer_fn)
        activation_fn = Registry.name2nn(activation_fn)
        norm_fn = Registry.name2nn(norm_fn)
        out_activation = Registry.name2nn(out_activation)
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

        if isinstance(state_shape, int):
            state_shape = (state_shape, )

        if len(state_shape) in [1, 2]:
            # linear case: a single observation or a history of observations
            # state_shape like [history_len, obs_shape]
            # @TODO: handle lama/rnn correctly
            if not history_aggregation_type:
                state_size = reduce(lambda x, y: x * y, state_shape)
            else:
                state_size = reduce(lambda x, y: x * y, state_shape[1:])

            if len(observation_hiddens) > 0:
                observation_net = SequentialNet(hiddens=[state_size] +
                                                observation_hiddens,
                                                layer_fn=layer_fn,
                                                dropout=dropout,
                                                activation_fn=activation_fn,
                                                norm_fn=norm_fn,
                                                bias=bias,
                                                layer_order=layer_order,
                                                residual=residual)
                observation_net.apply(inner_init)
                obs_out = observation_hiddens[-1]
            else:
                observation_net = None
                obs_out = state_size

        elif len(state_shape) in [3, 4]:
            # cnn case: a single image or a stack of images @TODO
            raise NotImplementedError
        else:
            raise NotImplementedError

        assert obs_out

        if history_aggregation_type == "lama_obs":
            aggregation_net = LamaPooling(features_in=obs_out,
                                          poolings=lama_poolings)
            aggregation_out = aggregation_net.features_out
        else:
            aggregation_net = None
            aggregation_out = obs_out

        main_net = SequentialNet(hiddens=[aggregation_out] + head_hiddens[:-1],
                                 layer_fn=layer_fn,
                                 dropout=dropout,
                                 activation_fn=activation_fn,
                                 norm_fn=norm_fn,
                                 bias=bias,
                                 layer_order=layer_order,
                                 residual=residual)
        main_net.apply(inner_init)

        # @TODO: place for memory network

        # assumes head_hiddens has at least two entries; the last one is the output size
        head_net = SequentialNet(hiddens=[head_hiddens[-2], head_hiddens[-1]],
                                 layer_fn=nn.Linear,
                                 activation_fn=out_activation,
                                 norm_fn=None,
                                 bias=True)
        head_net.apply(outer_init)

        critic_net = cls(observation_net=observation_net,
                         aggregation_net=aggregation_net,
                         main_net=main_net,
                         head_net=head_net,
                         policy_net=None)

        return critic_net
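As with the actor factory above, this looks like a classmethod on a critic class ("Critic" below is only an illustrative name). head_hiddens must contain at least two entries, with the last one acting as the output size:

import torch.nn as nn

# hypothetical call; "Critic" stands in for the class that owns create_from_params
critic = Critic.create_from_params(
    state_shape=(24,),
    head_hiddens=[256, 1],     # main net maps features to 256, head maps 256 -> 1
    activation_fn=nn.ReLU,
)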
Example #6
class TSN(nn.Module):
    def __init__(self,
                 encoder,
                 num_classes,
                 feature_net_hiddens=None,
                 emb_net_hiddens=None,
                 activation_fn=torch.nn.ReLU,
                 norm_fn=None,
                 bias=True,
                 dropout=None,
                 consensus=None,
                 kernel_size=1,
                 feature_net_skip_connection=False,
                 early_consensus=True):
        super().__init__()

        assert consensus is not None
        assert kernel_size in [1, 3, 5]

        consensus = consensus if isinstance(consensus, list) else [consensus]
        self.consensus = consensus

        self.encoder = encoder
        # guard against the default dropout=None, which nn.Dropout cannot handle
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()
        self.feature_net_skip_connection = feature_net_skip_connection
        self.early_consensus = early_consensus

        nonlinearity = registry.MODULES.get_if_str(activation_fn)
        inner_init = create_optimal_inner_init(nonlinearity=nonlinearity)
        kernel2pad = {1: 0, 3: 1, 5: 2}

        def layer_fn(in_features, out_features, bias=True):
            return nn.Conv1d(in_features,
                             out_features,
                             bias=bias,
                             kernel_size=kernel_size,
                             padding=kernel2pad[kernel_size])

        if feature_net_hiddens is not None:
            self.feature_net = SequentialNet(
                hiddens=[encoder.out_features] + [feature_net_hiddens],
                layer_fn=layer_fn,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.feature_net.apply(inner_init)
            out_features = feature_net_hiddens
        else:
            # without a feature net there is no need for a skip connection
            # (there is nothing to skip)
            assert not self.feature_net_skip_connection
            self.feature_net = lambda x: x
            out_features = encoder.out_features

        # Differences are starting here

        # Input channels to consensus function
        # (also to embedding net multiplied by len(consensus))
        if self.feature_net_skip_connection:
            in_channels = out_features + encoder.out_features
        else:
            in_channels = out_features

        consensus_fn = OrderedDict()
        for key in sorted(consensus):
            if key == "attention":
                self.attn = nn.Sequential(
                    nn.Conv1d(in_channels=in_channels,
                              out_channels=1,
                              kernel_size=kernel_size,
                              padding=kernel2pad[kernel_size],
                              bias=True), nn.Softmax(dim=1))

                def self_attn_fn(x):
                    x_a = x.transpose(1, 2)
                    x_attn = (self.attn(x_a) * x_a)
                    x_attn = x_attn.transpose(1, 2)
                    x_attn = x_attn.mean(1, keepdim=True)
                    return x_attn

                consensus_fn["attention"] = self_attn_fn
            elif key == "avg":
                consensus_fn[key] = lambda x: x.mean(1, keepdim=True)
            elif key == "max":
                consensus_fn[key] = lambda x: x.max(1, keepdim=True)[0]

        # Not optimized, but the logic is easier to follow
        if self.early_consensus:
            out_features = emb_net_hiddens

            self.emb_net = SequentialNet(
                hiddens=[in_channels * len(consensus_fn), emb_net_hiddens],
                layer_fn=nn.Linear,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.emb_net.apply(inner_init)
        else:
            if self.feature_net_skip_connection:
                out_features = out_features + self.encoder.out_features
            else:
                out_features = out_features

        self.head = nn.Linear(out_features, num_classes, bias=True)

        if 'attention' in consensus:
            self.attn.apply(outer_init)
        self.head.apply(outer_init)

        self.consensus_fn = consensus_fn

    def forward(self, input):
        if len(input.shape) < 5:
            input = input.unsqueeze(1)
        bs, fl, ch, h, w = input.shape
        x = input.view(-1, ch, h, w)
        x = self.encoder(x)
        x = self.dropout(x)
        identity = x

        # in simple case feature_net is identity mapping
        x = x.view(bs, fl, -1)
        x = x.transpose(1, 2)
        x = self.feature_net(x)
        x = x.transpose(1, 2).contiguous()  # because conv1d
        x = x.view(bs * fl, -1)
        if self.feature_net_skip_connection:
            x = torch.cat([identity, x], dim=-1)
        else:
            x = x

        if self.early_consensus:
            x = x.view(bs, fl, -1)
            c_list = []

            for c_fn in self.consensus_fn.values():
                c_res = c_fn(x)
                c_list.append(c_res)
            x = torch.cat(c_list, dim=1)
            x = x.view(bs, -1)
            x = self.emb_net(x)

        x = self.head(x)

        if not self.early_consensus:
            x = x.view(bs, fl, -1)

            if self.consensus[0] == "avg":
                x = x.mean(1, keepdim=False)
            elif self.consensus[0] == "attention":
                identity = identity.view(bs, fl, -1)
                x_a = identity.transpose(1, 2)
                x_ = x.transpose(1, 2)
                x_attn = (self.attn(x_a) * x_)
                x_attn = x_attn.transpose(1, 2)
                x = x_attn.sum(1, keepdim=False)

        x = torch.sigmoid(x)  # with bce loss
        return x
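A hedged end-to-end sketch: any encoder that exposes an out_features attribute and maps (N, C, H, W) frames to (N, out_features) should fit the forward pass above; the tiny stand-in encoder below is purely illustrative:

import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    # stand-in encoder: maps (N, 3, 32, 32) frames to (N, out_features)
    def __init__(self, out_features=64):
        super().__init__()
        self.out_features = out_features
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, out_features))

    def forward(self, x):
        return self.net(x)

tsn = TSN(
    encoder=DummyEncoder(),
    num_classes=10,
    feature_net_hiddens=128,
    emb_net_hiddens=64,
    consensus=["avg", "max"],
    dropout=0.5,
)

frames = torch.randn(2, 5, 3, 32, 32)   # (batch, num_frames, channels, height, width)
probs = tsn(frames)                     # (2, 10) sigmoid outputs, suited to BCE loss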