Example #1
    def __init__(
        self,
        hiddens,
        layer_fn=nn.Linear,
        bias=True,
        norm_fn=None,
        activation_fn=nn.ReLU,
        dropout=None,
        layer_order=None,
        residual=False
    ):
        super().__init__()
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry

        layer_fn = Registry.name2nn(layer_fn)
        activation_fn = Registry.name2nn(activation_fn)
        norm_fn = Registry.name2nn(norm_fn)
        dropout = Registry.name2nn(dropout)

        layer_order = layer_order or ["layer", "norm", "drop", "act"]

        if isinstance(dropout, float):
            dropout_fn = lambda: nn.Dropout(dropout)
        else:
            dropout_fn = dropout

        def _layer_fn(f_in, f_out, bias):
            return layer_fn(f_in, f_out, bias=bias)

        def _normalize_fn(f_in, f_out, bias):
            return norm_fn(f_out) if norm_fn is not None else None

        def _dropout_fn(f_in, f_out, bias):
            return dropout_fn() if dropout_fn is not None else None

        def _activation_fn(f_in, f_out, bias):
            return activation_fn() if activation_fn is not None else None

        name2fn = {
            "layer": _layer_fn,
            "norm": _normalize_fn,
            "drop": _dropout_fn,
            "act": _activation_fn,
        }

        net = []

        for i, (f_in, f_out) in enumerate(pairwise(hiddens)):
            block = []
            for key in layer_order:
                fn = name2fn[key](f_in, f_out, bias)
                if fn is not None:
                    block.append((f"{key}", fn))
            block = torch.nn.Sequential(OrderedDict(block))
            if residual:
                block = ResidualWrapper(net=block)
            net.append((f"block_{i}", block))

        self.net = torch.nn.Sequential(OrderedDict(net))
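This is the SequentialNet constructor (Example #7 builds the class with exactly these keyword arguments). A minimal usage sketch; the import path is an assumption, as is the handling of a float dropout by name2nn:

    import torch.nn as nn
    from catalyst.contrib.modules import SequentialNet  # assumed import path

    # two blocks (784 -> 256, 256 -> 64), each ordered
    # layer -> norm -> drop -> act per the default layer_order
    mlp = SequentialNet(
        hiddens=[784, 256, 64],
        layer_fn=nn.Linear,
        norm_fn=nn.BatchNorm1d,
        activation_fn=nn.ReLU,
        dropout=0.5,  # assumes name2nn passes floats through unchanged,
                      # as the isinstance(dropout, float) branch implies
    )

With residual=True, each block would additionally be wrapped in ResidualWrapper.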
Example #2
    def __init__(self,
                 action_size,
                 layer_fn,
                 activation_fn=nn.ReLU,
                 squashing_fn=nn.Tanh,
                 norm_fn=None,
                 bias=False):
        super().__init__()
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry
        activation_fn = Registry.name2nn(activation_fn)
        self.action_size = action_size

        self.coupling1 = CouplingLayer(action_size=action_size,
                                       layer_fn=layer_fn,
                                       activation_fn=activation_fn,
                                       norm_fn=None,
                                       bias=bias,
                                       parity="odd")
        self.coupling2 = CouplingLayer(action_size=action_size,
                                       layer_fn=layer_fn,
                                       activation_fn=activation_fn,
                                       norm_fn=None,
                                       bias=bias,
                                       parity="even")
        self.squashing_layer = SquashingLayer(squashing_fn)
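Example #7 constructs this same module under the name RealNVPPolicy; a sketch using that name:

    import torch.nn as nn

    policy = RealNVPPolicy(    # class name taken from Example #7
        action_size=6,
        layer_fn=nn.Linear,
        activation_fn="ReLU",  # resolved here via Registry.name2nn
        squashing_fn=nn.Tanh,  # passed through to SquashingLayer, which
                               # resolves it itself (see Example #3)
        bias=False,
    )

Note that squashing_fn is not resolved in this constructor; SquashingLayer runs its own name2nn call, so a registry-name string works there as well.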
Example #3
    def __init__(self, squashing_fn=nn.Tanh):
        """ Layer that squashes samples from some distribution to be bounded.
        """
        super().__init__()
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry
        self.squashing_fn = Registry.name2nn(squashing_fn)()
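Because the constructor resolves squashing_fn through Registry.name2nn and instantiates it immediately, either a registry name or an nn.Module class works:

    import torch.nn as nn

    squash_a = SquashingLayer(squashing_fn="Tanh")   # registry-name string
    squash_b = SquashingLayer(squashing_fn=nn.Tanh)  # class; equivalent result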
Example #4
    def __init__(self, in_features, activation_fn="Tanh"):
        super().__init__()
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry
        activation_fn = Registry.name2nn(activation_fn)
        self.attn = nn.Sequential(
            nn.Conv2d(
                in_features, 1, kernel_size=1, stride=1, padding=0, bias=False
            ), activation_fn()
        )
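The snippet omits the class name; a hypothetical instantiation of this 1x1-convolution attention head (AttentionModule is a placeholder name):

    # a 1x1 conv collapses C feature channels into a single-channel
    # spatial attention map, followed by the chosen activation
    attn = AttentionModule(in_features=512)                      # default Tanh
    attn2 = AttentionModule(in_features=512, activation_fn="Sigmoid")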
Example #5
    def __init__(self,
                 arch="resnet34",
                 pretrained=True,
                 frozen=True,
                 pooling=None,
                 pooling_kwargs=None,
                 cut_layers=2):
        super().__init__()
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry

        resnet = torchvision.models.__dict__[arch](pretrained=pretrained)
        modules = list(resnet.children())[:-cut_layers]  # delete last layers

        if frozen:
            for module in modules:
                for param in module.parameters():
                    param.requires_grad = False

        if pooling is not None:
            pooling_kwargs = pooling_kwargs or {}
            pooling_layer_fn = Registry.name2nn(pooling)
            pooling_layer = pooling_layer_fn(
                in_features=resnet.fc.in_features, **pooling_kwargs) \
                if "attn" in pooling.lower() \
                else pooling_layer_fn(**pooling_kwargs)
            modules += [pooling_layer]

            out_features = pooling_layer.out_features(
                in_features=resnet.fc.in_features)
        else:
            out_features = resnet.fc.in_features

        flatten = Registry.name2nn("Flatten")
        modules += [flatten()]
        self.out_features = out_features

        self.encoder = nn.Sequential(*modules)
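A usage sketch, assuming the class is a catalyst ResnetEncoder and that the pooling name below is registered in catalyst's Registry:

    encoder = ResnetEncoder(        # assumed class name
        arch="resnet18",
        pretrained=True,
        frozen=True,                # backbone weights won't receive grads
        pooling="GlobalAvgPool2d",  # assumed registry name; a name containing
                                    # "attn" would also receive in_features
    )
    print(encoder.out_features)     # feature width after pooling/flatten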
Example #6
    def __init__(
        self,
        input_size=224,
        width_mult=1.,
        pretrained=True,
        pooling=None,
        pooling_kwargs=None,
    ):
        super().__init__()
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry

        net = MobileNetV2(input_size=input_size,
                          width_mult=width_mult,
                          pretrained=pretrained)
        self.encoder = list(net.encoder.children())

        if pooling is not None:
            pooling_kwargs = pooling_kwargs or {}
            pooling_layer_fn = Registry.name2nn(pooling)
            pooling_layer = pooling_layer_fn(
                in_features=net.output_channel, **pooling_kwargs) \
                if "attn" in pooling.lower() \
                else pooling_layer_fn(**pooling_kwargs)
            self.encoder.append(pooling_layer)

            out_features = pooling_layer.out_features(
                in_features=net.output_channel)
        else:
            out_features = net.output_channel

        self.out_features = out_features
        # make it nn.Sequential
        self.encoder = nn.Sequential(*self.encoder)

        self._initialize_weights()
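A sketch under the assumption that this is a MobileNet encoder wrapper analogous to Example #5 (the class name is not shown):

    encoder = MobileNetEncoder(     # hypothetical class name
        input_size=224,
        width_mult=1.0,
        pretrained=True,
        pooling="GlobalMaxPool2d",  # assumed registry name
    )
    print(encoder.out_features)     # net.output_channel, or the pooling's
                                    # out_features when pooling is set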
Example #7
    @classmethod
    def create_from_params(cls,
                           state_shape,
                           action_size,
                           observation_hiddens=None,
                           head_hiddens=None,
                           layer_fn=nn.Linear,
                           activation_fn=nn.ReLU,
                           dropout=None,
                           norm_fn=None,
                           bias=True,
                           layer_order=None,
                           residual=False,
                           out_activation=None,
                           observation_aggregation=None,
                           lama_poolings=None,
                           policy_type=None,
                           squashing_fn=nn.Tanh,
                           **kwargs):
        assert len(kwargs) == 0
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry

        observation_hiddens = observation_hiddens or []
        head_hiddens = head_hiddens or []

        layer_fn = Registry.name2nn(layer_fn)
        activation_fn = Registry.name2nn(activation_fn)
        norm_fn = Registry.name2nn(norm_fn)
        out_activation = Registry.name2nn(out_activation)
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

        if isinstance(state_shape, int):
            state_shape = (state_shape, )

        if len(state_shape) in [1, 2]:
            # linear case: a single observation or a history of them
            # state_shape like [history_len, obs_shape]
            # @TODO: handle lama/rnn correctly
            if not observation_aggregation:
                observation_size = reduce(lambda x, y: x * y, state_shape)
            else:
                observation_size = reduce(lambda x, y: x * y, state_shape[1:])

            if len(observation_hiddens) > 0:
                observation_net = SequentialNet(hiddens=[observation_size] +
                                                observation_hiddens,
                                                layer_fn=layer_fn,
                                                dropout=dropout,
                                                activation_fn=activation_fn,
                                                norm_fn=norm_fn,
                                                bias=bias,
                                                layer_order=layer_order,
                                                residual=residual)
                observation_net.apply(inner_init)
                obs_out = observation_hiddens[-1]
            else:
                observation_net = None
                obs_out = observation_size

        elif len(state_shape) in [3, 4]:
            # cnn case: a single image or a history of them @TODO
            raise NotImplementedError
        else:
            raise NotImplementedError

        assert obs_out

        if observation_aggregation == "lama_obs":
            aggregation_net = LamaPooling(features_in=obs_out,
                                          poolings=lama_poolings)
            aggregation_out = aggregation_net.features_out
        else:
            aggregation_net = None
            aggregation_out = obs_out

        main_net = SequentialNet(hiddens=[aggregation_out] + head_hiddens,
                                 layer_fn=layer_fn,
                                 dropout=dropout,
                                 activation_fn=activation_fn,
                                 norm_fn=norm_fn,
                                 bias=bias,
                                 layer_order=layer_order,
                                 residual=residual)
        main_net.apply(inner_init)

        # @TODO: place for memory network

        if policy_type == "gauss":
            head_size = action_size * 2
            policy_net = GaussPolicy(squashing_fn)
        elif policy_type == "real_nvp":
            head_size = action_size * 2
            policy_net = RealNVPPolicy(action_size=action_size,
                                       layer_fn=layer_fn,
                                       activation_fn=activation_fn,
                                       squashing_fn=squashing_fn,
                                       norm_fn=None,
                                       bias=bias)
        else:
            head_size = action_size
            policy_net = None

        head_net = SequentialNet(hiddens=[head_hiddens[-1], head_size],
                                 layer_fn=nn.Linear,
                                 activation_fn=out_activation,
                                 norm_fn=None,
                                 bias=True)
        head_net.apply(outer_init)

        actor_net = cls(observation_net=observation_net,
                        aggregation_net=aggregation_net,
                        main_net=main_net,
                        head_net=head_net,
                        policy_net=policy_net)

        return actor_net
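A sketch of calling this factory; the host class name is an assumption, everything else follows from the code above:

    actor = Actor.create_from_params(  # "Actor" is a hypothetical host class;
                                       # the snippet shows only the classmethod
        state_shape=(24,),
        action_size=4,
        head_hiddens=[256, 128],  # must be non-empty: head_net reads
                                  # head_hiddens[-1] as its input width
        layer_fn="Linear",        # registry names work for every *_fn arg
        activation_fn="ReLU",
        policy_type="gauss",      # or "real_nvp"; anything else yields
                                  # a plain deterministic head
    )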
Example #8
    def __init__(self,
                 action_size,
                 layer_fn,
                 activation_fn=nn.ReLU,
                 norm_fn=None,
                 bias=True,
                 parity="odd"):
        """ Conditional affine coupling layer used in Real NVP Bijector.
        Original paper: https://arxiv.org/abs/1605.08803
        Adaptation to RL: https://arxiv.org/abs/1804.02808
        Important notes
        ---------------
        1. State embeddings are supposed to have size (action_size * 2).
        2. The scale and translation networks used in the Real NVP Bijector
        both have one hidden layer of action_size units with an
        activation_fn nonlinearity.
        3. Parity ("odd" or "even") determines which part of the input
        is being copied and which is being transformed.
        """
        super().__init__()
        # hack to prevent circular imports
        from catalyst.contrib.registry import Registry

        layer_fn = Registry.name2nn(layer_fn)
        activation_fn = Registry.name2nn(activation_fn)
        norm_fn = Registry.name2nn(norm_fn)

        self.parity = parity
        if self.parity == "odd":
            self.copy_size = action_size // 2
        else:
            self.copy_size = action_size - action_size // 2

        self.scale_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)
        self.scale_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

        self.translation_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)
        self.translation_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
        self.scale_prenet.apply(inner_init)
        self.scale_net.apply(outer_init)
        self.translation_prenet.apply(inner_init)
        self.translation_net.apply(outer_init)
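Mirroring the two calls in Example #2, a standalone sketch:

    import torch.nn as nn

    # parity picks the split: "odd" leaves action_size // 2 components
    # untransformed, "even" leaves the remainder, so the two layers in
    # Example #2 jointly transform every action dimension
    odd = CouplingLayer(action_size=6, layer_fn=nn.Linear, parity="odd")
    even = CouplingLayer(action_size=6, layer_fn=nn.Linear, parity="even")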