def __init__(
    self,
    observation_shapes: List[Tuple[int, ...]],
    network_settings: NetworkSettings,
    encoded_act_size: int = 0,
):
    """Build the observation-encoding body of the network.

    Creates per-observation processors, a linear encoder over the
    concatenated encodings (plus any encoded action input), and an
    optional LSTM when memory is configured in ``network_settings``.
    """
    super().__init__()
    self.normalize = network_settings.normalize
    memory_settings = network_settings.memory
    self.use_lstm = memory_settings is not None
    self.h_size = network_settings.hidden_units
    # No memory configured -> zero-sized memory vector.
    self.m_size = memory_settings.memory_size if memory_settings is not None else 0

    (
        self.visual_processors,
        self.vector_processors,
        encoder_input_size,
    ) = ModelUtils.create_input_processors(
        observation_shapes,
        self.h_size,
        network_settings.vis_encode_type,
        normalize=self.normalize,
    )

    # Encoded actions (if any) are concatenated onto the observation encoding.
    self.linear_encoder = LinearEncoder(
        encoder_input_size + encoded_act_size,
        network_settings.num_layers,
        self.h_size,
    )

    if self.use_lstm:
        self.lstm = LSTM(self.h_size, self.m_size)
    else:
        self.lstm = None  # type: ignore
def __init__(
    self,
    sensor_specs: List[SensorSpec],
    network_settings: NetworkSettings,
    encoded_act_size: int = 0,
):
    """Build the observation-encoding body of the network.

    Creates one processor per sensor spec, a linear encoder over the sum
    of their embedding sizes (plus any encoded action input), and an
    optional LSTM when memory is configured in ``network_settings``.
    """
    super().__init__()
    self.normalize = network_settings.normalize
    memory_settings = network_settings.memory
    self.use_lstm = memory_settings is not None
    self.h_size = network_settings.hidden_units
    # No memory configured -> zero-sized memory vector.
    self.m_size = memory_settings.memory_size if memory_settings is not None else 0

    self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
        sensor_specs,
        self.h_size,
        network_settings.vis_encode_type,
        normalize=self.normalize,
    )

    # Encoded actions (if any) are concatenated onto the observation encoding.
    combined_size = sum(self.embedding_sizes) + encoded_act_size
    self.linear_encoder = LinearEncoder(
        combined_size, network_settings.num_layers, self.h_size
    )

    if self.use_lstm:
        self.lstm = LSTM(self.h_size, self.m_size)
    else:
        self.lstm = None  # type: ignore
def __init__(
    self,
    observation_specs: List[ObservationSpec],
    network_settings: NetworkSettings,
    encoded_act_size: int = 0,
):
    """Build the observation-encoding body of the network.

    Delegates per-observation encoding to an ``ObservationEncoder``,
    then applies a linear encoder over the concatenated encoding (plus
    any encoded action input), followed by an optional LSTM when memory
    is configured in ``network_settings``.
    """
    super().__init__()
    self.normalize = network_settings.normalize
    memory_settings = network_settings.memory
    self.use_lstm = memory_settings is not None
    self.h_size = network_settings.hidden_units
    # No memory configured -> zero-sized memory vector.
    self.m_size = memory_settings.memory_size if memory_settings is not None else 0

    self.observation_encoder = ObservationEncoder(
        observation_specs,
        self.h_size,
        network_settings.vis_encode_type,
        self.normalize,
    )
    self.processors = self.observation_encoder.processors

    # Encoded actions (if any) are concatenated onto the observation encoding.
    combined_size = self.observation_encoder.total_enc_size + encoded_act_size
    self.linear_encoder = LinearEncoder(
        combined_size, network_settings.num_layers, self.h_size
    )

    if self.use_lstm:
        self.lstm = LSTM(self.h_size, self.m_size)
    else:
        self.lstm = None  # type: ignore
def __init__(
    self,
    observation_specs: List[ObservationSpec],
    network_settings: NetworkSettings,
    encoded_act_size: int = 0,
):
    """Build the observation-encoding body of the network.

    Creates per-observation processors; when any of them are
    ``EntityEmbedding`` (variable-length) processors, adds a residual
    self-attention module (and, if there are fixed-size encodings, a
    self-encoder feeding it). The combined encoding (plus any encoded
    action input) goes through a linear encoder and an optional LSTM.
    """
    super().__init__()
    self.normalize = network_settings.normalize
    memory_settings = network_settings.memory
    self.use_lstm = memory_settings is not None
    self.h_size = network_settings.hidden_units
    # No memory configured -> zero-sized memory vector.
    self.m_size = memory_settings.memory_size if memory_settings is not None else 0

    self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
        observation_specs,
        self.h_size,
        network_settings.vis_encode_type,
        normalize=self.normalize,
    )

    # Variable-length observation processors participate in self-attention.
    attention_processors = [
        p for p in self.processors if isinstance(p, EntityEmbedding)
    ]
    # Total of the per-processor max entity counts; counts of 0 mean the
    # maximum was not known at construction and are excluded.
    known_entity_total: int = sum(
        p.entity_num_max_elements
        for p in attention_processors
        if p.entity_num_max_elements > 0
    )

    embed_total = sum(self.embedding_sizes)
    if attention_processors:
        if embed_total:
            # Encodes the fixed-size ("self") portion of the observation
            # before it joins the attention over variable-length entities.
            self.x_self_encoder = LinearEncoder(
                embed_total,
                1,
                self.h_size,
                kernel_init=Initialization.Normal,
                kernel_gain=(0.125 / self.h_size) ** 0.5,
            )
        self.rsa = ResidualSelfAttention(self.h_size, known_entity_total)
        total_enc_size = embed_total + self.h_size
    else:
        total_enc_size = embed_total

    # Encoded actions (if any) are concatenated onto the observation encoding.
    self.linear_encoder = LinearEncoder(
        total_enc_size + encoded_act_size,
        network_settings.num_layers,
        self.h_size,
    )

    if self.use_lstm:
        self.lstm = LSTM(self.h_size, self.m_size)
    else:
        self.lstm = None  # type: ignore
def __init__(
    self,
    observation_specs: List[ObservationSpec],
    network_settings: NetworkSettings,
    action_spec: ActionSpec,
):
    """Build a multi-agent network body.

    Each agent's observations (optionally joined with its actions) are
    embedded and pooled across agents with residual self-attention,
    then passed through a linear encoder and an optional LSTM.
    """
    super().__init__()
    self.normalize = network_settings.normalize
    memory_settings = network_settings.memory
    self.use_lstm = memory_settings is not None
    self.h_size = network_settings.hidden_units
    # No memory configured -> zero-sized memory vector.
    self.m_size = memory_settings.memory_size if memory_settings is not None else 0
    self.action_spec = action_spec

    self.observation_encoder = ObservationEncoder(
        observation_specs,
        self.h_size,
        network_settings.vis_encode_type,
        self.normalize,
    )
    self.processors = self.observation_encoder.processors

    # Modules for multi-agent self-attention.
    obs_only_ent_size = self.observation_encoder.total_enc_size
    # Entity size when actions are appended: observations + one slot per
    # discrete branch's one-hot + continuous action dimensions.
    q_ent_size = (
        obs_only_ent_size
        + sum(self.action_spec.discrete_branches)
        + self.action_spec.continuous_size
    )
    attention_embedding_size = self.h_size
    self.obs_encoder = EntityEmbedding(
        obs_only_ent_size, None, attention_embedding_size
    )
    self.obs_action_encoder = EntityEmbedding(
        q_ent_size, None, attention_embedding_size
    )
    self.self_attn = ResidualSelfAttention(attention_embedding_size)

    self.linear_encoder = LinearEncoder(
        attention_embedding_size,
        network_settings.num_layers,
        self.h_size,
        kernel_gain=(0.125 / self.h_size) ** 0.5,
    )

    if self.use_lstm:
        self.lstm = LSTM(self.h_size, self.m_size)
    else:
        self.lstm = None  # type: ignore

    # Non-trainable parameter (requires_grad=False) initialized to 1;
    # presumably tracks the largest agent count seen so far — registered
    # on the module so it travels with the model state. TODO confirm.
    self._current_max_agents = torch.nn.Parameter(
        torch.as_tensor(1), requires_grad=False
    )
def test_lstm_class():
    """LSTM wrapper reports its memory size and produces correctly
    shaped output and memory tensors."""
    torch.manual_seed(0)
    n_features = 12
    mem_size = 64
    n_batch, n_seq = 8, 16

    lstm = LSTM(n_features, mem_size)
    assert lstm.memory_size == mem_size

    inputs = torch.ones((n_batch, n_seq, n_features))
    memories = torch.ones((1, n_batch, mem_size))
    output, new_memories = lstm(inputs, memories)

    # Hidden size should be half of memory_size (the memory tensor packs
    # both hidden and cell state).
    assert output.shape == (n_batch, n_seq, mem_size // 2)
    assert new_memories.shape == (1, n_batch, mem_size)