Exemplo n.º 1
0
def test_valueheads():
    """Check the per-stream output shapes produced by ``ValueHeads``.

    Covers three cases: the default construction (one value per stream),
    a call with an input of the wrong width (expected to raise), and a
    construction with several outputs per stream.
    """
    n_features = 5
    n_samples = 4
    reward_streams = [f"reward_signal_{idx}" for idx in range(5)]

    # Default construction: each stream yields one value per batch row.
    heads = ValueHeads(reward_streams, n_features)
    features = torch.ones((n_samples, n_features))
    # Note: mean value will be removed shortly
    outputs = heads(features)

    for name in reward_streams:
        assert outputs[name].shape == (n_samples,)

    # Feeding an input whose width does not match must raise.
    with pytest.raises(Exception):
        heads(torch.ones((n_samples, n_features + 2)))

    # Several values per stream (e.g. discrete Q function).
    n_outputs = 4
    heads = ValueHeads(reward_streams, n_features, n_outputs)
    features = torch.ones((n_samples, n_features))
    outputs = heads(features)

    for name in reward_streams:
        assert outputs[name].shape == (n_samples, n_outputs)
Exemplo n.º 2
0
 def __init__(
     self,
     sensor_specs: List[SensorSpec],
     network_settings: NetworkSettings,
     action_spec: ActionSpec,
     stream_names: List[str],
     conditional_sigma: bool = False,
     tanh_squash: bool = False,
 ):
     """Initialize the parent class and attach one value head per stream.

     :param sensor_specs: Forwarded unchanged to the parent initializer.
     :param network_settings: Forwarded to the parent initializer; its
         ``memory`` field also decides ``self.use_lstm``.
     :param action_spec: Forwarded unchanged to the parent initializer.
     :param stream_names: Names handed to ``ValueHeads`` and stored on
         the instance.
     :param conditional_sigma: Forwarded unchanged to the parent initializer.
     :param tanh_squash: Forwarded unchanged to the parent initializer.
     """
     # The flag is assigned before the parent initializer is invoked;
     # that ordering is kept deliberately.
     memory_settings = network_settings.memory
     self.use_lstm = memory_settings is not None
     super().__init__(
         sensor_specs,
         network_settings,
         action_spec,
         conditional_sigma,
         tanh_squash,
     )
     self.stream_names = stream_names
     # NOTE(review): self.encoding_size is presumably provided by the
     # parent class -- confirm against its definition.
     self.value_heads = ValueHeads(stream_names, self.encoding_size)
Exemplo n.º 3
0
        def __init__(
            self,
            stream_names: List[str],
            observation_specs: List[ObservationSpec],
            network_settings: NetworkSettings,
            action_spec: ActionSpec,
        ):
            """Build the network body and a single-output value head per stream.

            :param stream_names: Names handed to ``ValueHeads``.
            :param observation_specs: Passed to ``MultiAgentNetworkBody``.
            :param network_settings: Passed to ``MultiAgentNetworkBody``; also
                determines the input size of the value heads.
            :param action_spec: Passed to ``MultiAgentNetworkBody``.
            """
            # Call the torch.nn.Module initializer explicitly.
            torch.nn.Module.__init__(self)
            self.network_body = MultiAgentNetworkBody(
                observation_specs, network_settings, action_spec
            )

            # Without memory the heads consume ``hidden_units`` features;
            # with memory they consume half of ``memory_size``.
            if network_settings.memory is None:
                encoding_size = network_settings.hidden_units
            else:
                encoding_size = network_settings.memory.memory_size // 2

            self.value_heads = ValueHeads(stream_names, encoding_size, 1)
Exemplo n.º 4
0
 def __init__(
     self,
     observation_shapes: List[Tuple[int, ...]],
     network_settings: NetworkSettings,
     action_spec: ActionSpec,
     stream_names: List[str],
     conditional_sigma: bool = False,
     tanh_squash: bool = False,
 ):
     """Initialize the parent class and attach one value head per stream.

     :param observation_shapes: Forwarded unchanged to the parent initializer.
     :param network_settings: Forwarded unchanged to the parent initializer.
     :param action_spec: Forwarded unchanged to the parent initializer.
     :param stream_names: Names handed to ``ValueHeads`` and stored on
         the instance.
     :param conditional_sigma: Forwarded unchanged to the parent initializer.
     :param tanh_squash: Forwarded unchanged to the parent initializer.
     """
     super().__init__(observation_shapes, network_settings, action_spec,
                      conditional_sigma, tanh_squash)
     self.stream_names = stream_names
     # NOTE(review): self.encoding_size is presumably provided by the
     # parent class -- confirm against its definition.
     self.value_heads = ValueHeads(stream_names, self.encoding_size)
Exemplo n.º 5
0
    def __init__(
        self,
        stream_names: List[str],
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        outputs_per_stream: int = 1,
    ):
        """Build the network body and the value heads that sit on top of it.

        :param stream_names: Names handed to ``ValueHeads``.
        :param observation_shapes: Passed to ``NetworkBody``.
        :param network_settings: Passed to ``NetworkBody``; also determines
            the input size of the value heads.
        :param encoded_act_size: Passed to ``NetworkBody`` by keyword.
        :param outputs_per_stream: Passed to ``ValueHeads`` as its third
            argument.
        """
        # This is not a typo, we want to call __init__ of nn.Module
        nn.Module.__init__(self)
        self.network_body = NetworkBody(
            observation_shapes,
            network_settings,
            encoded_act_size=encoded_act_size,
        )
        # With memory configured the heads consume half of ``memory_size``;
        # otherwise they consume ``hidden_units`` features.
        encoding_size = (
            network_settings.memory.memory_size // 2
            if network_settings.memory is not None
            else network_settings.hidden_units
        )
        self.value_heads = ValueHeads(
            stream_names, encoding_size, outputs_per_stream
        )