Пример #1
0
def test_convolution():
    """Smoke check: build a convolution net from a sample config, print it.

    Every per-layer option is given as a three-element list; nothing is
    asserted, the constructed network is only printed.
    """
    # Activation given in dict form: module name plus its keyword argument.
    elu_spec = {"module": "ELU", "alpha": 0.5}

    net_config = dict(
        in_channels=3,
        channels=[128, 64, 64],
        kernel_sizes=[8, 4, 3],
        strides=[4, 2, 1],
        groups=[1, 2, 2],
        use_bias=[True, False, False],
        normalization=[None, "BatchNorm2d", "BatchNorm2d"],
        dropout_rate=[None, 0.5, 0.8],
        activation=[None, "ReLU", elu_spec],
        residual="soft",
    )

    print(get_convolution_net(**net_config))
Пример #2
0
def _get_observation_net(state_shape, **observation_net_params):
    """Build the observation network selected by ``_network_type``.

    Args:
        state_shape: shape of a state; index 0 is the history length and
            the remaining entries describe a single observation.
        **observation_net_params: keyword arguments for the network
            builder. The optional ``"_network_type"`` entry (``"linear"``
            by default) selects the builder and is removed before the
            remaining arguments are forwarded.

    Returns:
        The result of ``get_linear_net`` (with ``in_features`` set to the
        product of ``state_shape[1:]``) or of ``get_convolution_net``
        (with ``in_channels`` set to ``state_shape[1]``).

    Raises:
        NotImplementedError: if ``"_network_type"`` is neither
            ``"linear"`` nor ``"convolution"``.
    """
    # @TODO: make more general and move to contrib
    # Work on a deep copy so nested values passed by the caller are not
    # shared with (or mutated through) the parameters used below.
    observation_net_params = deepcopy(observation_net_params)
    observation_net_type = \
        observation_net_params.pop("_network_type", "linear")

    if observation_net_type == "linear":
        # Axis 0 is the history length; flatten everything after it.
        # The initial value 1 keeps reduce() from raising on a state
        # shape that has no axes beyond the history one.
        observation_size = reduce(lambda x, y: x * y, state_shape[1:], 1)
        observation_net_params["in_features"] = observation_size
        observation_net = get_linear_net(**observation_net_params)
    elif observation_net_type == "convolution":
        # Axis 0 is the history length; axis 1 is used as channel count.
        observation_net_params["in_channels"] = state_shape[1]
        observation_net = get_convolution_net(**observation_net_params)
    else:
        raise NotImplementedError(
            "Unsupported _network_type {!r}; "
            "expected 'linear' or 'convolution'".format(observation_net_type)
        )

    return observation_net