Пример #1
0
    def create_from_params(
        cls,
        state_shape,
        action_size,
        observation_hiddens=None,
        head_hiddens=None,
        layer_fn=nn.Linear,
        activation_fn=nn.ReLU,
        dropout=None,
        norm_fn=None,
        bias=True,
        layer_order=None,
        residual=False,
        out_activation=None,
        observation_aggregation=None,
        lama_poolings=None,
        policy_type=None,
        squashing_fn=nn.Tanh,
        **kwargs
    ):
        """Build an actor network from plain configuration values.

        The network is assembled from up to five parts that are handed to
        ``cls``: an optional observation net, an optional aggregation net,
        the main net, the output head and an optional policy net.

        Args:
            state_shape: int or sequence of ints. An int is wrapped into a
                one-element tuple. Only shapes with 1 or 2 dimensions are
                supported.
            action_size: number of action dimensions; the head outputs
                ``action_size`` values, or ``2 * action_size`` for the
                ``"gauss"`` and ``"real_nvp"`` policies.
            observation_hiddens: sizes of the observation net layers.
                When empty or ``None`` no observation net is created.
            head_hiddens: sizes of the main net layers; must contain at
                least one entry because the last one is the head input size.
            layer_fn, activation_fn, norm_fn, out_activation: modules or
                their string names, resolved through ``MODULES.get_if_str``.
            dropout, bias, layer_order, residual: forwarded unchanged to
                ``SequentialNet`` for the observation and main nets.
            observation_aggregation: when falsy, the whole ``state_shape``
                is flattened into a single observation vector; otherwise
                the first dimension is skipped when computing the
                observation size. ``"lama_obs"`` additionally inserts
                a ``LamaPooling`` aggregation net.
            lama_poolings: forwarded to ``LamaPooling`` as ``poolings``.
            policy_type: ``"gauss"``, ``"real_nvp"`` or anything else
                (no policy net).
            squashing_fn: forwarded to the policy net.
            **kwargs: must be empty; exists only to reject unknown options.

        Returns:
            The object produced by calling ``cls`` with the created nets.

        Raises:
            AssertionError: if unknown keyword arguments are passed or the
                computed observation output size is falsy.
            NotImplementedError: if ``state_shape`` does not have
                1 or 2 dimensions.
            ValueError: if ``head_hiddens`` is empty.
        """
        assert len(kwargs) == 0, \
            "Unexpected arguments: {}".format(sorted(kwargs))

        # Copy into fresh lists so that tuples are accepted as well:
        # the sizes are concatenated with lists below.
        observation_hiddens = list(observation_hiddens or [])
        head_hiddens = list(head_hiddens or [])

        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)
        norm_fn = MODULES.get_if_str(norm_fn)
        out_activation = MODULES.get_if_str(out_activation)
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

        if isinstance(state_shape, int):
            state_shape = (state_shape,)

        if len(state_shape) in [1, 2]:
            # linear case: a single observation or a history of observations,
            # state_shape like [history_len, obs_shape]
            # @TODO: handle lama/rnn correctly
            if not observation_aggregation:
                # no aggregation: flatten everything into one vector
                observation_size = reduce(lambda x, y: x * y, state_shape)
            else:
                # aggregation: the first dimension is not part of
                # the per-observation size
                observation_size = reduce(
                    lambda x, y: x * y, state_shape[1:]
                )

            if len(observation_hiddens) > 0:
                observation_net = SequentialNet(
                    hiddens=[observation_size] + observation_hiddens,
                    layer_fn=layer_fn,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    norm_fn=norm_fn,
                    bias=bias,
                    layer_order=layer_order,
                    residual=residual
                )
                observation_net.apply(inner_init)
                obs_out = observation_hiddens[-1]
            else:
                observation_net = None
                obs_out = observation_size

        elif len(state_shape) in [3, 4]:
            # cnn case: a single image or a history of images @TODO
            raise NotImplementedError
        else:
            raise NotImplementedError

        assert obs_out

        if observation_aggregation == "lama_obs":
            aggregation_net = LamaPooling(
                features_in=obs_out,
                poolings=lama_poolings
            )
            aggregation_out = aggregation_net.features_out
        else:
            aggregation_net = None
            aggregation_out = obs_out

        # The head below reads its input size from head_hiddens[-1];
        # fail with a clear message instead of an opaque IndexError.
        if not head_hiddens:
            raise ValueError(
                "head_hiddens must contain at least one layer size"
            )

        main_net = SequentialNet(
            hiddens=[aggregation_out] + head_hiddens,
            layer_fn=layer_fn,
            dropout=dropout,
            activation_fn=activation_fn,
            norm_fn=norm_fn,
            bias=bias,
            layer_order=layer_order,
            residual=residual
        )
        main_net.apply(inner_init)

        # @TODO: place for memory network

        if policy_type == "gauss":
            head_size = action_size * 2
            policy_net = GaussPolicy(squashing_fn)
        elif policy_type == "real_nvp":
            head_size = action_size * 2
            policy_net = RealNVPPolicy(
                action_size=action_size,
                layer_fn=layer_fn,
                activation_fn=activation_fn,
                squashing_fn=squashing_fn,
                norm_fn=None,
                bias=bias
            )
        else:
            head_size = action_size
            policy_net = None

        # The output head ignores the configurable layer/norm/bias options:
        # it is always a biased nn.Linear without normalization.
        head_net = SequentialNet(
            hiddens=[head_hiddens[-1], head_size],
            layer_fn=nn.Linear,
            activation_fn=out_activation,
            norm_fn=None,
            bias=True
        )
        head_net.apply(outer_init)

        actor_net = cls(
            observation_net=observation_net,
            aggregation_net=aggregation_net,
            main_net=main_net,
            head_net=head_net,
            policy_net=policy_net
        )

        return actor_net
Пример #2
0
    def create_from_params(cls,
                           state_shape,
                           observation_hiddens=None,
                           head_hiddens=None,
                           layer_fn=nn.Linear,
                           activation_fn=nn.ReLU,
                           dropout=None,
                           norm_fn=None,
                           bias=True,
                           layer_order=None,
                           residual=False,
                           out_activation=None,
                           history_aggregation_type=None,
                           lama_poolings=None,
                           **kwargs):
        """Build a critic network from plain configuration values.

        The network is assembled from an optional observation net, an
        optional aggregation net, the main net and the output head, which
        are handed to ``cls`` together with ``policy_net=None``.

        Args:
            state_shape: int or sequence of ints. An int is wrapped into a
                one-element tuple. Only shapes with 1 or 2 dimensions are
                supported.
            observation_hiddens: sizes of the observation net layers.
                When empty or ``None`` no observation net is created.
            head_hiddens: at least two sizes. All but the last one are the
                main net layers; the last two are the input and output
                sizes of the head.
            layer_fn, activation_fn, norm_fn, out_activation: modules or
                their string names, resolved through ``Registry.name2nn``.
            dropout, bias, layer_order, residual: forwarded unchanged to
                ``SequentialNet`` for the observation and main nets.
            history_aggregation_type: when falsy, the whole ``state_shape``
                is flattened into a single state vector; otherwise the
                first dimension is skipped when computing the state size.
                ``"lama_obs"`` additionally inserts a ``LamaPooling``
                aggregation net.
            lama_poolings: forwarded to ``LamaPooling`` as ``poolings``.
            **kwargs: must be empty; exists only to reject unknown options.

        Returns:
            The object produced by calling ``cls`` with the created nets.

        Raises:
            AssertionError: if unknown keyword arguments are passed or the
                computed observation output size is falsy.
            NotImplementedError: if ``state_shape`` does not have
                1 or 2 dimensions.
            ValueError: if ``head_hiddens`` has fewer than two entries.
        """
        assert len(kwargs) == 0, \
            "Unexpected arguments: {}".format(sorted(kwargs))
        # hack to prevent cycle imports
        from catalyst.contrib.registry import Registry

        # Copy into fresh lists so that tuples are accepted as well:
        # the sizes are concatenated with lists below.
        observation_hiddens = list(observation_hiddens or [])
        head_hiddens = list(head_hiddens or [])

        layer_fn = Registry.name2nn(layer_fn)
        activation_fn = Registry.name2nn(activation_fn)
        norm_fn = Registry.name2nn(norm_fn)
        out_activation = Registry.name2nn(out_activation)
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

        if isinstance(state_shape, int):
            state_shape = (state_shape, )

        if len(state_shape) in [1, 2]:
            # linear case: a single observation or a history of observations,
            # state_shape like [history_len, obs_shape]
            # @TODO: handle lama/rnn correctly
            if not history_aggregation_type:
                # no aggregation: flatten everything into one vector
                state_size = reduce(lambda x, y: x * y, state_shape)
            else:
                # aggregation: the first dimension is not part of
                # the per-observation size
                state_size = reduce(lambda x, y: x * y, state_shape[1:])

            if len(observation_hiddens) > 0:
                observation_net = SequentialNet(hiddens=[state_size] +
                                                observation_hiddens,
                                                layer_fn=layer_fn,
                                                dropout=dropout,
                                                activation_fn=activation_fn,
                                                norm_fn=norm_fn,
                                                bias=bias,
                                                layer_order=layer_order,
                                                residual=residual)
                observation_net.apply(inner_init)
                obs_out = observation_hiddens[-1]
            else:
                observation_net = None
                obs_out = state_size

        elif len(state_shape) in [3, 4]:
            # cnn case: a single image or a history of images @TODO
            raise NotImplementedError
        else:
            raise NotImplementedError

        assert obs_out

        if history_aggregation_type == "lama_obs":
            aggregation_net = LamaPooling(features_in=obs_out,
                                          poolings=lama_poolings)
            aggregation_out = aggregation_net.features_out
        else:
            aggregation_net = None
            aggregation_out = obs_out

        # The head below reads head_hiddens[-2] and head_hiddens[-1];
        # fail with a clear message instead of an opaque IndexError.
        if len(head_hiddens) < 2:
            raise ValueError(
                "head_hiddens must contain at least two sizes: "
                "the head input size and the head output size")

        # The last entry of head_hiddens is the head output size, so the
        # main net only gets the entries before it.
        main_net = SequentialNet(hiddens=[aggregation_out] + head_hiddens[:-1],
                                 layer_fn=layer_fn,
                                 dropout=dropout,
                                 activation_fn=activation_fn,
                                 norm_fn=norm_fn,
                                 bias=bias,
                                 layer_order=layer_order,
                                 residual=residual)
        main_net.apply(inner_init)

        # @TODO: place for memory network

        # The output head ignores the configurable layer/norm/bias options:
        # it is always a biased nn.Linear without normalization.
        head_net = SequentialNet(hiddens=[head_hiddens[-2], head_hiddens[-1]],
                                 layer_fn=nn.Linear,
                                 activation_fn=out_activation,
                                 norm_fn=None,
                                 bias=True)
        head_net.apply(outer_init)

        critic_net = cls(observation_net=observation_net,
                         aggregation_net=aggregation_net,
                         main_net=main_net,
                         head_net=head_net,
                         policy_net=None)

        return critic_net