def test_valueheads():
    """Check ValueHeads output shapes and its rejection of mis-sized input.

    One head is built per reward-stream name. The test verifies that:
      * by default every stream yields a ``(batch,)`` tensor;
      * an input whose feature width differs from ``input_size`` raises;
      * requesting ``k`` outputs per stream yields ``(batch, k)`` tensors.
    """
    names = [f"reward_signal_{idx}" for idx in range(5)]
    in_features = 5
    n_batch = 4

    # Default configuration: a single scalar value per stream.
    heads = ValueHeads(names, in_features)
    ones_input = torch.ones((n_batch, in_features))
    outputs = heads(ones_input)
    for name in names:
        assert outputs[name].shape == (n_batch,)

    # A feature dimension that does not match in_features must be rejected.
    with pytest.raises(Exception):
        heads(torch.ones((n_batch, in_features + 2)))

    # Several values per stream, e.g. one per action for a discrete Q function.
    per_stream = 4
    heads = ValueHeads(names, in_features, per_stream)
    outputs = heads(torch.ones((n_batch, in_features)))
    for name in names:
        assert outputs[name].shape == (n_batch, per_stream)
def __init__(
    self,
    sensor_specs: List[SensorSpec],
    network_settings: NetworkSettings,
    action_spec: ActionSpec,
    stream_names: List[str],
    conditional_sigma: bool = False,
    tanh_squash: bool = False,
):
    """Initialize the parent network and attach one value head per stream.

    :param sensor_specs: Sensor specifications, forwarded to the parent constructor.
    :param network_settings: Network configuration; its ``memory`` field
        decides ``self.use_lstm``.
    :param action_spec: Action specification, forwarded to the parent constructor.
    :param stream_names: Names of the value streams; one value head is created per name.
    :param conditional_sigma: Forwarded unchanged to the parent constructor.
    :param tanh_squash: Forwarded unchanged to the parent constructor.
    """
    # Set before the parent constructor runs. Ordering preserved deliberately.
    # NOTE(review): presumably the base-class __init__ reads use_lstm; confirm.
    has_memory = network_settings.memory is not None
    self.use_lstm = has_memory
    super().__init__(
        sensor_specs,
        network_settings,
        action_spec,
        conditional_sigma,
        tanh_squash,
    )
    self.stream_names = stream_names
    # self.encoding_size is not assigned in this method. It is presumably
    # set by the parent constructor above.
    self.value_heads = ValueHeads(stream_names, self.encoding_size)
def __init__(
    self,
    stream_names: List[str],
    observation_specs: List[ObservationSpec],
    network_settings: NetworkSettings,
    action_spec: ActionSpec,
):
    """Build a multi-agent network body and value heads on top of it.

    :param stream_names: Names of the value streams; one value head (with a
        single output) is created per name.
    :param observation_specs: Observation specifications for the network body.
    :param network_settings: Network configuration. Memory settings (if any)
        determine the encoding size fed to the value heads.
    :param action_spec: Action specification for the network body.
    """
    # Call nn.Module's initializer directly instead of super().__init__().
    # NOTE(review): presumably this skips other base classes' __init__ in the
    # MRO; confirm against the class definition.
    torch.nn.Module.__init__(self)
    self.network_body = MultiAgentNetworkBody(
        observation_specs, network_settings, action_spec
    )
    memory = network_settings.memory
    # With memory enabled, the heads consume half of memory_size; otherwise
    # they use hidden_units. (Presumably the memory vector packs two halves,
    # e.g. LSTM hidden and cell state; confirm.)
    encoding_size = (
        memory.memory_size // 2
        if memory is not None
        else network_settings.hidden_units
    )
    self.value_heads = ValueHeads(stream_names, encoding_size, 1)
def __init__(
    self,
    observation_shapes: List[Tuple[int, ...]],
    network_settings: NetworkSettings,
    action_spec: ActionSpec,
    stream_names: List[str],
    conditional_sigma: bool = False,
    tanh_squash: bool = False,
):
    """Initialize the parent network and attach one value head per stream.

    :param observation_shapes: Observation shapes, forwarded to the parent constructor.
    :param network_settings: Network configuration, forwarded to the parent constructor.
    :param action_spec: Action specification, forwarded to the parent constructor.
    :param stream_names: Names of the value streams; one value head is created per name.
    :param conditional_sigma: Forwarded unchanged to the parent constructor.
    :param tanh_squash: Forwarded unchanged to the parent constructor.
    """
    parent_args = (
        observation_shapes,
        network_settings,
        action_spec,
        conditional_sigma,
        tanh_squash,
    )
    super().__init__(*parent_args)
    self.stream_names = stream_names
    # self.encoding_size is not assigned here. It is presumably set by the
    # parent constructor called above.
    self.value_heads = ValueHeads(stream_names, self.encoding_size)
def __init__(
    self,
    stream_names: List[str],
    observation_shapes: List[Tuple[int, ...]],
    network_settings: NetworkSettings,
    encoded_act_size: int = 0,
    outputs_per_stream: int = 1,
):
    """Build a network body and value heads on top of its encoding.

    :param stream_names: Names of the value streams; one value head is created per name.
    :param observation_shapes: Observation shapes for the network body.
    :param network_settings: Network configuration. Memory settings (if any)
        determine the encoding size fed to the value heads.
    :param encoded_act_size: Forwarded to NetworkBody as ``encoded_act_size``.
    :param outputs_per_stream: Number of outputs each value head produces.
    """
    # Intentional: invoke nn.Module's initializer directly rather than
    # super().__init__(). NOTE(review): presumably to bypass other base
    # classes' __init__ in the MRO; confirm.
    nn.Module.__init__(self)
    self.network_body = NetworkBody(
        observation_shapes, network_settings, encoded_act_size=encoded_act_size
    )
    memory = network_settings.memory
    # With memory enabled, the heads consume half of memory_size; otherwise
    # they use hidden_units. (Presumably the memory vector packs two halves,
    # e.g. LSTM hidden and cell state; confirm.)
    encoding_size = (
        memory.memory_size // 2
        if memory is not None
        else network_settings.hidden_units
    )
    self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)