Example #1
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (network_settings.memory.memory_size
                       if network_settings.memory is not None else 0)

        self.visual_processors, self.vector_processors, encoder_input_size = ModelUtils.create_input_processors(
            observation_shapes,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )
        total_enc_size = encoder_input_size + encoded_act_size
        self.linear_encoder = LinearEncoder(total_enc_size,
                                            network_settings.num_layers,
                                            self.h_size)

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore
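
This first version builds per-observation processors from raw observation_shapes, concatenates their outputs (plus an optional encoded action) into a LinearEncoder, and attaches an LSTM only when memory is configured. Below is a minimal, torch-only sketch of that encode-then-optional-LSTM pattern; the class and module names are stand-ins made up for illustration, not the ml-agents implementations.

import torch
from torch import nn
from typing import List, Optional, Tuple

class EncodeThenLSTM(nn.Module):
    """Hypothetical stand-in for the pattern above, not the real code."""

    def __init__(self, obs_sizes: List[int], h_size: int, m_size: int = 0):
        super().__init__()
        # One processor per observation (the real code picks visual or
        # vector encoders based on each observation's shape).
        self.processors = nn.ModuleList(nn.Linear(s, h_size) for s in obs_sizes)
        self.linear_encoder = nn.Sequential(
            nn.Linear(h_size * len(obs_sizes), h_size), nn.SiLU()
        )
        # Hidden width is m_size // 2: the memory tensor packs hidden and
        # cell state together (see the test in Example #6).
        self.lstm = nn.LSTM(h_size, m_size // 2, batch_first=True) if m_size else None

    def forward(
        self,
        obs: List[torch.Tensor],
        memories: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        encoded = torch.cat([p(x) for p, x in zip(self.processors, obs)], dim=-1)
        encoding = self.linear_encoder(encoded)
        if self.lstm is not None:
            encoding, memories = self.lstm(encoding, memories)
        return encoding, memories
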
Example #2
    def __init__(
        self,
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (network_settings.memory.memory_size
                       if network_settings.memory is not None else 0)

        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
            sensor_specs,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )

        total_enc_size = sum(self.embedding_sizes) + encoded_act_size
        self.linear_encoder = LinearEncoder(total_enc_size,
                                            network_settings.num_layers,
                                            self.h_size)

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore
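
Example #2 is the same constructor after the input API moved from raw observation_shapes to SensorSpec objects: create_input_processors now returns the processors together with their per-sensor embedding sizes, and the caller sums those sizes itself. A hypothetical sketch of that return contract (the helper below is invented for illustration):

from typing import List, Tuple
from torch import nn

def make_processors(
    input_sizes: List[int], h_size: int
) -> Tuple[nn.ModuleList, List[int]]:
    # Hypothetical helper: one encoder per sensor plus its output width,
    # mirroring the (processors, embedding_sizes) pair returned above.
    processors = nn.ModuleList(nn.Linear(size, h_size) for size in input_sizes)
    embedding_sizes = [h_size] * len(input_sizes)
    return processors, embedding_sizes

processors, embedding_sizes = make_processors([8, 4], h_size=32)
total_enc_size = sum(embedding_sizes)  # 64, before adding encoded_act_size
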
Example #3
    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (network_settings.memory.memory_size
                       if network_settings.memory is not None else 0)
        self.observation_encoder = ObservationEncoder(
            observation_specs,
            self.h_size,
            network_settings.vis_encode_type,
            self.normalize,
        )
        self.processors = self.observation_encoder.processors
        total_enc_size = self.observation_encoder.total_enc_size
        total_enc_size += encoded_act_size
        self.linear_encoder = LinearEncoder(total_enc_size,
                                            network_settings.num_layers,
                                            self.h_size)

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore
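
Example #3 folds the processors and the size bookkeeping into a single ObservationEncoder, so the network body only reads total_enc_size back. A hypothetical stand-in showing that encapsulation (not the real ObservationEncoder):

import torch
from typing import List
from torch import nn

class ObsEncoderSketch(nn.Module):
    # Hypothetical stand-in for ObservationEncoder: it owns the processors
    # and exposes total_enc_size so callers no longer sum sizes themselves.
    def __init__(self, input_sizes: List[int], h_size: int):
        super().__init__()
        self.processors = nn.ModuleList(nn.Linear(s, h_size) for s in input_sizes)
        self.total_enc_size = h_size * len(input_sizes)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        return torch.cat([p(x) for p, x in zip(self.processors, inputs)], dim=-1)
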
Example #4
    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (network_settings.memory.memory_size
                       if network_settings.memory is not None else 0)

        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
            observation_specs,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )

        entity_num_max: int = 0
        var_processors = [
            p for p in self.processors if isinstance(p, EntityEmbedding)
        ]
        for processor in var_processors:
            entity_max: int = processor.entity_num_max_elements
            # Only add the entity max if it was known at construction
            if entity_max > 0:
                entity_num_max += entity_max
        if len(var_processors) > 0:
            if sum(self.embedding_sizes):
                self.x_self_encoder = LinearEncoder(
                    sum(self.embedding_sizes),
                    1,
                    self.h_size,
                    kernel_init=Initialization.Normal,
                    kernel_gain=(0.125 / self.h_size)**0.5,
                )
            self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
            total_enc_size = sum(self.embedding_sizes) + self.h_size
        else:
            total_enc_size = sum(self.embedding_sizes)

        total_enc_size += encoded_act_size
        self.linear_encoder = LinearEncoder(total_enc_size,
                                            network_settings.num_layers,
                                            self.h_size)

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore
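
Example #4 adds support for variable-length observations: each EntityEmbedding processor reports the maximum number of entities it can see (when known), an optional x_self_encoder embeds the fixed-size observations that condition the attention, and a ResidualSelfAttention pools the variable set into one h_size vector that is concatenated with the fixed encoding. A rough, torch-only sketch of the attend-and-pool step, using nn.MultiheadAttention as a stand-in for the real ResidualSelfAttention:

import torch
from torch import nn

class AttentionPoolSketch(nn.Module):
    # Rough stand-in for ResidualSelfAttention: attend over a variable
    # number of entity embeddings, then pool to one fixed-size vector.
    def __init__(self, embed_size: int, num_heads: int = 4):
        super().__init__()
        # embed_size must be divisible by num_heads.
        self.attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, entities: torch.Tensor, key_padding_mask=None):
        # entities: (batch, num_entities, embed_size); padded slots masked out.
        attended, _ = self.attn(
            entities, entities, entities, key_padding_mask=key_padding_mask
        )
        pooled = self.norm(entities + attended)  # residual connection
        return pooled.mean(dim=1)  # fixed-size summary for the linear encoder
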
Example #5
    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (network_settings.memory.memory_size
                       if network_settings.memory is not None else 0)
        self.action_spec = action_spec
        self.observation_encoder = ObservationEncoder(
            observation_specs,
            self.h_size,
            network_settings.vis_encode_type,
            self.normalize,
        )
        self.processors = self.observation_encoder.processors

        # Modules for multi-agent self-attention
        obs_only_ent_size = self.observation_encoder.total_enc_size
        q_ent_size = (obs_only_ent_size +
                      sum(self.action_spec.discrete_branches) +
                      self.action_spec.continuous_size)

        attention_embedding_size = self.h_size
        self.obs_encoder = EntityEmbedding(obs_only_ent_size, None,
                                           attention_embedding_size)
        self.obs_action_encoder = EntityEmbedding(q_ent_size, None,
                                                  attention_embedding_size)

        self.self_attn = ResidualSelfAttention(attention_embedding_size)

        self.linear_encoder = LinearEncoder(
            attention_embedding_size,
            network_settings.num_layers,
            self.h_size,
            kernel_gain=(0.125 / self.h_size)**0.5,
        )

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore
        self._current_max_agents = torch.nn.Parameter(torch.as_tensor(1),
                                                      requires_grad=False)
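
Example #5 is the multi-agent variant: the current agent is embedded from observations alone (obs_encoder), other agents are embedded from observations plus their actions (obs_action_encoder, whose input width adds the discrete branch sizes and the continuous action size), and a single self-attention layer pools all agents; _current_max_agents is a non-trainable Parameter so it is saved and restored with checkpoints. A hypothetical illustration of that entity layout, with made-up sizes:

import torch
from torch import nn

embed_size, obs_size, act_size = 64, 32, 6          # invented for illustration
obs_encoder = nn.Linear(obs_size, embed_size)       # stands in for EntityEmbedding
obs_action_encoder = nn.Linear(obs_size + act_size, embed_size)

self_obs = torch.randn(1, 1, obs_size)              # this agent: obs only
others = torch.randn(1, 3, obs_size + act_size)     # other agents: obs + action
entities = torch.cat([obs_encoder(self_obs), obs_action_encoder(others)], dim=1)
# entities: (1, 4, embed_size) -> one self-attention pools all four agents
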
Example #6
def test_lstm_class():
    torch.manual_seed(0)
    input_size = 12
    memory_size = 64
    batch_size = 8
    seq_len = 16
    lstm = LSTM(input_size, memory_size)

    assert lstm.memory_size == memory_size

    sample_input = torch.ones((batch_size, seq_len, input_size))
    sample_memories = torch.ones((1, batch_size, memory_size))
    out, mem = lstm(sample_input, sample_memories)
    # Hidden size should be half of memory_size
    assert out.shape == (batch_size, seq_len, memory_size // 2)
    assert mem.shape == (1, batch_size, memory_size)
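
The test pins down the memory layout the earlier examples rely on: the LSTM's single memory tensor packs hidden and cell state side by side, so the sequence output is memory_size // 2 wide. A minimal sketch consistent with those assertions (not the ml-agents LSTM itself):

import torch
from torch import nn

class LSTMSketch(nn.Module):
    # Consistent with the test above, but a hypothetical reimplementation.
    def __init__(self, input_size: int, memory_size: int):
        super().__init__()
        assert memory_size % 2 == 0, "memory packs (hidden, cell) halves"
        self.memory_size = memory_size
        self.lstm = nn.LSTM(input_size, memory_size // 2, batch_first=True)

    def forward(self, inputs: torch.Tensor, memories: torch.Tensor):
        # Split the packed memory into the LSTM's hidden and cell states.
        h0, c0 = torch.split(memories, self.memory_size // 2, dim=-1)
        out, (h, c) = self.lstm(inputs, (h0.contiguous(), c0.contiguous()))
        return out, torch.cat([h, c], dim=-1)

out, mem = LSTMSketch(12, 64)(torch.ones(8, 16, 12), torch.ones(1, 8, 64))
assert out.shape == (8, 16, 32) and mem.shape == (1, 8, 64)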