예제 #1
0
    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
    ):
        """
        Build the observation encoder, the self-attention modules, the linear
        encoder and (when memory settings are present) the LSTM.
        :param observation_specs: Specs forwarded to the ObservationEncoder.
        :param network_settings: Source of normalize, hidden_units, num_layers,
            vis_encode_type and the optional memory settings.
        :param action_spec: Its discrete branches and continuous size are added
            to the encoded observation size to get the observation+action
            entity size.
        """
        super().__init__()
        memory_settings = network_settings.memory
        has_memory = memory_settings is not None

        self.normalize = network_settings.normalize
        self.use_lstm = has_memory
        self.h_size = network_settings.hidden_units
        # Without memory settings there is no recurrent state, hence size 0.
        self.m_size = memory_settings.memory_size if has_memory else 0
        self.action_spec = action_spec

        self.observation_encoder = ObservationEncoder(
            observation_specs,
            self.h_size,
            network_settings.vis_encode_type,
            self.normalize,
        )
        self.processors = self.observation_encoder.processors

        # Modules for multi-agent self-attention
        obs_entity_size = self.observation_encoder.total_enc_size
        discrete_total = sum(self.action_spec.discrete_branches)
        obs_action_entity_size = (
            obs_entity_size + discrete_total + self.action_spec.continuous_size
        )

        # The attention embedding shares the width of the hidden layers.
        embedding_size = self.h_size
        self.obs_encoder = EntityEmbedding(
            obs_entity_size, None, embedding_size
        )
        self.obs_action_encoder = EntityEmbedding(
            obs_action_entity_size, None, embedding_size
        )

        self.self_attn = ResidualSelfAttention(embedding_size)

        self.linear_encoder = LinearEncoder(
            embedding_size,
            network_settings.num_layers,
            self.h_size,
            kernel_gain=(0.125 / self.h_size) ** 0.5,
        )

        self.lstm = (  # type: ignore
            LSTM(self.h_size, self.m_size) if self.use_lstm else None
        )
        # Non-trainable parameter initialised to 1.
        self._current_max_agents = torch.nn.Parameter(
            torch.as_tensor(1), requires_grad=False
        )
예제 #2
0
def test_predict_minimum_training():
    """
    Train an entity embedding, a residual self-attention block and a linear
    head to classify the index of the smallest value among a random number
    (1 to n_k) of random inputs, then assert that the mean cross-entropy of
    the last 20 steps is below 0.1.
    """
    # of 5 numbers, predict index of min
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embedding = EntityEmbedding(size, n_k, embedding_size)  # no self
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embedding.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )

    batch_size = 200
    # torch.range is deprecated; torch.arange over the half-open range
    # [0, n_k) yields the same n_k values. The default floating dtype is
    # requested explicitly so the tensor handed to actions_to_onehot keeps
    # the dtype torch.range used to produce.
    indices = torch.arange(n_k, dtype=torch.get_default_dtype())
    onehots = ModelUtils.actions_to_onehot(indices.unsqueeze(1), [n_k])[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # create the target : The minimum
            # argmin over dim 1 leaves shape (batch_size, 1); only the
            # trailing axis is squeezed so the batch axis is always kept.
            argmin = torch.argmin(inp, dim=1).squeeze(-1)
        # Append to every value the one-hot of its position in the sequence.
        sliced_oh = onehots[:, :num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)

        embeddings = entity_embedding(inp, inp)
        masks = get_zero_entities_mask([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        print(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1
예제 #3
0
def test_predict_closest_training():
    """
    Train an entity embedding (with self embedding), a residual self-attention
    block and a linear layer to output the key that is closest to a center
    vector in squared Euclidean distance, then assert that the error of the
    final training step is below 0.02.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    # Row indices used to pick one key per sample without a Python loop.
    batch_index = torch.arange(batch_size)
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            offsets = center.reshape((batch_size, 1, size)) - key
            distance = torch.sum(offsets**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            # Advanced indexing selects key[i, argmin[i], :] for every i,
            # giving a (batch_size, size) tensor.
            target = key[batch_index, argmin]

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        # Call the module itself rather than .forward so that any registered
        # hooks run, matching how the other modules are invoked.
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
예제 #4
0
def test_all_masking(mask_value):
    """
    Run a few optimisation steps of the closest-key task with a mask whose
    every entry equals ``mask_value``; the test passes if no step raises.
    :param mask_value: Value that fills the whole mask (all zeros or all ones
        according to the comment below).
    """
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    # Row indices used to pick one key per sample without a Python loop.
    batch_index = torch.arange(batch_size)
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            offsets = center.reshape((batch_size, 1, size)) - key
            distance = torch.sum(offsets**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            # Advanced indexing selects key[i, argmin[i], :] for every i,
            # giving a (batch_size, size) tensor.
            target = key[batch_index, argmin]

        embeddings = entity_embeddings(center, key)
        # One (batch_size, n_k) mask filled uniformly with mask_value.
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        # Call the module itself rather than .forward so that any registered
        # hooks run.
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
예제 #5
0
    def get_encoder_for_obs(
        obs_spec: ObservationSpec,
        normalize: bool,
        h_size: int,
        attention_embedding_size: int,
        vis_encode_type: EncoderType,
    ) -> Tuple[nn.Module, int]:
        """
        Pick the encoder matching an observation spec and report its size.
        :param obs_spec: Observation spec; its shape and dimension_property
            decide which encoder is built.
        :param normalize: Normalize all vector inputs.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param vis_encode_type: Type of visual encoder to use.
        :return: Tuple of (encoder, size). The size is h_size for visual
            observations, shape[0] for vector observations and 0 for
            variable-length observations.
        :raises UnityTrainerException: If the dimension property is in none of
            the visual, vector or variable-length sets.
        """
        obs_shape = obs_spec.shape
        prop = obs_spec.dimension_property

        if prop in ModelUtils.VALID_VISUAL_PROP:
            # VISUAL
            encoder_cls = ModelUtils.get_encoder_for_type(vis_encode_type)
            encoder = encoder_cls(
                obs_shape[0], obs_shape[1], obs_shape[2], h_size
            )
            encoded_size = h_size
        elif prop in ModelUtils.VALID_VECTOR_PROP:
            # VECTOR
            encoder = VectorInput(obs_shape[0], normalize)
            encoded_size = obs_shape[0]
        elif prop in ModelUtils.VALID_VAR_LEN_PROP:
            # VARIABLE LENGTH
            encoder = EntityEmbedding(
                entity_size=obs_shape[1],
                entity_num_max_elements=obs_shape[0],
                embedding_size=attention_embedding_size,
            )
            encoded_size = 0
        else:
            # OTHER
            raise UnityTrainerException(
                f"Unsupported Sensor with specs {obs_spec}")
        return (encoder, encoded_size)