コード例 #1
0
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        """Build the observation encoders and, when memory is configured, an LSTM.

        :param observation_shapes: Shapes of the observations, forwarded to
            ``ModelUtils.create_encoders``.
        :param network_settings: Supplies ``normalize``, ``hidden_units``,
            ``num_layers``, ``vis_encode_type`` and the optional ``memory``
            settings.
        :param encoded_act_size: Forwarded to ``create_encoders`` as
            ``unnormalized_inputs``.
        """
        super().__init__()
        self.normalize = network_settings.normalize

        # Memory is optional: its presence switches the LSTM on and sets m_size.
        memory_settings = network_settings.memory
        self.use_lstm = memory_settings is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            memory_settings.memory_size if memory_settings is not None else 0
        )

        encoders = ModelUtils.create_encoders(
            observation_shapes,
            self.h_size,
            network_settings.num_layers,
            network_settings.vis_encode_type,
            unnormalized_inputs=encoded_act_size,
            normalize=self.normalize,
        )
        self.visual_encoders, self.vector_encoders = encoders

        # The LSTM is sized to half of m_size -- presumably because the memory
        # vector carries both hidden and cell state; confirm against lstm_layer.
        self.lstm = (
            lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
            if self.use_lstm
            else None
        )
コード例 #2
0
def test_create_encoders(encoder_type, normalize, num_vector, num_visual,
                         unnormalized_inputs):
    """Check the count and types of encoders from ``ModelUtils.create_encoders``.

    All arguments are supplied by the caller (presumably pytest
    parametrization declared outside this view -- confirm).

    ``unnormalized_inputs`` is used exactly as passed in.  It was previously
    overwritten with ``1`` inside the body, which made the plain
    ``VectorEncoder`` branch and the "no vector encoder" case unreachable.

    :param encoder_type: Visual encoder type handed to ``create_encoders`` and
        to ``ModelUtils.get_encoder_for_type``.
    :param normalize: Normalization flag handed to ``create_encoders``.
    :param num_vector: Number of ``(5,)`` vector observations to declare.
    :param num_visual: Number of ``(84, 84, 3)`` visual observations to declare.
    :param unnormalized_inputs: Size of the unnormalized input handed to
        ``create_encoders``.
    """
    vec_obs_shape = (5, )
    vis_obs_shape = (84, 84, 3)
    # Vector shapes first, then visual ones.  Shapes are immutable tuples, so
    # repeating the same object is safe.
    obs_shapes = [vec_obs_shape] * num_vector + [vis_obs_shape] * num_visual
    h_size = 128
    num_layers = 3
    vis_enc, vec_enc = ModelUtils.create_encoders(obs_shapes, h_size,
                                                  num_layers, encoder_type,
                                                  unnormalized_inputs,
                                                  normalize)
    vec_enc = list(vec_enc)
    vis_enc = list(vis_enc)

    # There's always at most one vector encoder.
    expected_vec_count = 1 if unnormalized_inputs + num_vector > 0 else 0
    assert len(vec_enc) == expected_vec_count
    assert len(vis_enc) == num_visual

    if unnormalized_inputs > 0:
        assert isinstance(vec_enc[0], VectorAndUnnormalizedInputEncoder)
    elif num_vector > 0:
        assert isinstance(vec_enc[0], VectorEncoder)

    expected_vis_type = ModelUtils.get_encoder_for_type(encoder_type)
    for enc in vis_enc:
        assert isinstance(enc, expected_vis_type)