Example #1
def test_linear():
    net = get_linear_net(
        in_features=32,
        features=[128, 64, 64],
        use_bias=[True, False, False],
        normalization=[None, "BatchNorm1d", "LayerNorm"],
        dropout_rate=[None, 0.5, 0.8],
        activation=[None, "ReLU", {
            "module": "ELU",
            "alpha": 0.5
        }],
        residual="soft",
    )

    print(net)
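
Assuming the net returned by get_linear_net is an ordinary torch.nn.Module that maps in_features=32 inputs to 64-dimensional outputs (the last entry of features), a minimal smoke test appended to the snippet above could look like:

import torch

# Hypothetical check: push a small random batch through the net built above.
x = torch.randn(8, 32)             # batch of 8, in_features = 32
out = net(x)
assert out.shape == (8, 64)        # 64 is the last entry of `features`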
Example #2
from copy import deepcopy
from functools import reduce


def _get_observation_net(state_shape, **observation_net_params):
    # @TODO: make more general and move to contrib
    observation_net_params = deepcopy(observation_net_params)
    observation_net_type = \
        observation_net_params.pop("_network_type", "linear")

    if observation_net_type == "linear":
        # state_shape[0] is the history length; flatten the remaining dims into features
        observation_size = reduce(lambda x, y: x * y, state_shape[1:])
        observation_net_params["in_features"] = observation_size
        observation_net = get_linear_net(**observation_net_params)
    elif observation_net_type == "convolution":
        # state_shape[0] is the history length; dim 1 holds the channels
        observation_net_params["in_channels"] = state_shape[1]
        observation_net = get_convolution_net(**observation_net_params)
    else:
        raise NotImplementedError()

    return observation_net
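
For illustration only, a hypothetical call exercising the "linear" branch (the extra keyword arguments are invented and simply forwarded to get_linear_net):

# state_shape = (history_len, *observation_dims): 4 frames of 8x8 observations,
# so the linear branch flattens them into 8 * 8 = 64 input features.
observation_net = _get_observation_net(
    state_shape=(4, 8, 8),
    _network_type="linear",
    features=[128, 128],
)
print(observation_net)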
Example #3
    @classmethod
    def get_from_params(
        cls,
        state_shape,
        observation_net_params=None,
        aggregation_net_params=None,
        main_net_params=None,
    ) -> "StateNet":

        assert main_net_params is not None
        # @TODO: refactor, too complicated; fast&furious development

        main_net_in_features = 0
        observation_net_out_features = 0

        # observation net
        if observation_net_params is not None:
            key_value_flag = observation_net_params.pop("_key_value", False)

            if key_value_flag:
                observation_net = collections.OrderedDict()
                for key in observation_net_params:
                    net_, out_features_ = \
                        utils.get_observation_net(
                            state_shape[key],
                            **observation_net_params[key]
                        )
                    observation_net[key] = net_
                    observation_net_out_features += out_features_
                observation_net = nn.ModuleDict(observation_net)
            else:
                observation_net, observation_net_out_features = \
                    utils.get_observation_net(
                        state_shape,
                        **observation_net_params
                    )
            main_net_in_features += observation_net_out_features
        else:
            observation_net, observation_net_out_features = \
                utils.get_observation_net(state_shape)
            main_net_in_features += observation_net_out_features

        # aggregation net
        if aggregation_net_params is not None:
            aggregation_type = \
                aggregation_net_params.pop("_network_type", "concat")

            if aggregation_type == "concat":
                aggregation_net = TemporalConcatPooling(
                    observation_net_out_features, **aggregation_net_params
                )
            elif aggregation_type == "lama":
                aggregation_net = LamaPooling(
                    observation_net_out_features, **aggregation_net_params
                )
            else:
                raise NotImplementedError()

            main_net_in_features = aggregation_net.out_features
        else:
            aggregation_net = None

        # main net
        main_net_params["in_features"] = main_net_in_features
        main_net = get_linear_net(**main_net_params)

        net = cls(
            main_net=main_net,
            aggregation_net=aggregation_net,
            observation_net=observation_net
        )
        return net
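
A hypothetical invocation (the parameter dictionaries below are invented for illustration and only use keys handled by the method above):

net = StateNet.get_from_params(
    state_shape=(4, 8, 8),                            # history length first, then observation dims
    observation_net_params={"features": [128]},       # forwarded to get_observation_net / get_linear_net
    aggregation_net_params={"_network_type": "lama"},
    main_net_params={"features": [256, 64]},          # in_features is filled in by the method
)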