Example 1
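A unit test for get_zero_entities_mask: the inputs are built from alternating chunks of random and all-zero entities, and the test checks that each returned mask has shape (batch_size, n_entities) and is 0 over the random chunks and 1 over the zero chunks.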
import torch

# Assumed ml-agents import path; newer releases move this under
# mlagents.trainers.torch_entities.attention.
from mlagents.trainers.torch.attention import get_zero_entities_mask


def test_zero_mask_layer():
    batch_size, size = 10, 30

    def generate_input_helper(pattern):
        # Build a (batch, sum(pattern), size) tensor where even-indexed chunks
        # are random entities and odd-indexed chunks are all-zero padding.
        _input = torch.zeros((batch_size, 0, size))
        for i in range(len(pattern)):
            if i % 2 == 0:
                _input = torch.cat(
                    [_input,
                     torch.rand((batch_size, pattern[i], size))],
                    dim=1)
            else:
                _input = torch.cat(
                    [_input,
                     torch.zeros((batch_size, pattern[i], size))],
                    dim=1)
        return _input

    masking_pattern_1 = [3, 2, 3, 4]
    masking_pattern_2 = [5, 7, 8, 2]
    input_1 = generate_input_helper(masking_pattern_1)
    input_2 = generate_input_helper(masking_pattern_2)

    masks = get_zero_entities_mask([input_1, input_2])
    assert len(masks) == 2
    masks_1 = masks[0]
    masks_2 = masks[1]
    assert masks_1.shape == (batch_size, sum(masking_pattern_1))
    assert masks_2.shape == (batch_size, sum(masking_pattern_2))
    # Masks should be 0 over the random chunks and 1 over the all-zero chunks
    offset = 0
    for i, chunk_len in enumerate(masking_pattern_1):
        expected = 0 if i % 2 == 0 else 1
        assert torch.all(masks_1[:, offset:offset + chunk_len] == expected)
        offset += chunk_len
    offset = 0
    for i, chunk_len in enumerate(masking_pattern_2):
        expected = 0 if i % 2 == 0 else 1
        assert torch.all(masks_2[:, offset:offset + chunk_len] == expected)
        offset += chunk_len
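This test pins down the contract of get_zero_entities_mask: for each input of shape (batch, n_entities, size) it returns a (batch, n_entities) mask that is 1 where an entity's feature vector is all zeros and 0 elsewhere. A minimal sketch of an equivalent computation (a hypothetical stand-in, not the library's actual implementation):

from typing import List

import torch


def zero_entities_mask_sketch(entities: List[torch.Tensor]) -> List[torch.Tensor]:
    # 1.0 where an entity's feature vector (dim 2) is entirely zero, else 0.0;
    # illustrative equivalent of get_zero_entities_mask, not the library code
    with torch.no_grad():
        return [(ent.abs().sum(dim=2) == 0).float() for ent in entities]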
Example 2
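A forward pass (apparently from a network body) that encodes fixed-size observations directly, routes variable-length observations through their EntityEmbedding processors and an RSA using the zero-entity masks, optionally appends actions, and finishes with a linear encoder and an optional LSTM.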
    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encodes = []
        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

        for idx, processor in enumerate(self.processors):
            if not isinstance(processor, EntityEmbedding):
                # The input can be encoded without having to process other inputs
                obs_input = inputs[idx]
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            else:
                var_len_processor_inputs.append((processor, inputs[idx]))
        if len(encodes) != 0:
            encoded_self = torch.cat(encodes, dim=1)
            input_exist = True
        else:
            input_exist = False
        if len(var_len_processor_inputs) > 0:
            # Some inputs need to be processed with a variable length encoder
            masks = get_zero_entities_mask(
                [p_i[1] for p_i in var_len_processor_inputs])
            embeddings: List[torch.Tensor] = []
            processed_self = self.x_self_encoder(
                encoded_self) if input_exist else None
            for processor, var_len_input in var_len_processor_inputs:
                embeddings.append(processor(processed_self, var_len_input))
            qkv = torch.cat(embeddings, dim=1)
            attention_embedding = self.rsa(qkv, masks)
            if not input_exist:
                # cat of a single tensor is a no-op; use the embedding directly
                encoded_self = attention_embedding
                input_exist = True
            else:
                encoded_self = torch.cat([encoded_self, attention_embedding],
                                         dim=1)

        if not input_exist:
            raise Exception(
                "The trainer was unable to process any of the provided inputs. "
                "Make sure the trained agents has at least one sensor attached to them."
            )

        if actions is not None:
            encoded_self = torch.cat([encoded_self, actions], dim=1)
        encoding = self.linear_encoder(encoded_self)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
Example 3
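An observation-encoding forward method following the same encode/mask/attend pattern as Example 2, but returning the concatenated encoding directly and guarding on rsa and x_self_encoder being present.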
    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Encode observations using a list of processors and an RSA.
        :param inputs: List of Tensors corresponding to a set of obs.
        :param processors: a ModuleList of the input processors to be applied to these obs.
        :param rsa: Optionally, an RSA to use for variable length obs.
        :param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs.).
        """
        encodes = []
        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

        for idx, processor in enumerate(self.processors):
            if not isinstance(processor, EntityEmbedding):
                # The input can be encoded without having to process other inputs
                obs_input = inputs[idx]
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            else:
                var_len_processor_inputs.append((processor, inputs[idx]))
        if len(encodes) != 0:
            encoded_self = torch.cat(encodes, dim=1)
            input_exist = True
        else:
            input_exist = False
        if len(var_len_processor_inputs) > 0 and self.rsa is not None:
            # Some inputs need to be processed with a variable length encoder
            masks = get_zero_entities_mask(
                [p_i[1] for p_i in var_len_processor_inputs])
            embeddings: List[torch.Tensor] = []
            processed_self = (self.x_self_encoder(encoded_self) if input_exist
                              and self.x_self_encoder is not None else None)
            for processor, var_len_input in var_len_processor_inputs:
                embeddings.append(processor(processed_self, var_len_input))
            qkv = torch.cat(embeddings, dim=1)
            attention_embedding = self.rsa(qkv, masks)
            if not input_exist:
                # cat of a single tensor is a no-op; use the embedding directly
                encoded_self = attention_embedding
                input_exist = True
            else:
                encoded_self = torch.cat([encoded_self, attention_embedding],
                                         dim=1)

        if not input_exist:
            raise UnityTrainerException(
                "The trainer was unable to process any of the provided inputs. "
                "Make sure the trained agents has at least one sensor attached to them."
            )

        return encoded_self
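Examples 2 and 3 share the same pattern: fixed-size observations are encoded directly, variable-length observations are embedded per entity and attended over by the RSA, and the zero-entity masks keep padded entities out of the attention. A self-contained sketch of just that masking-plus-attention step (shapes and the import path are assumptions; passing None for x_self mirrors the branches above):

import torch

# Assumed import path; newer ml-agents releases move these under
# mlagents.trainers.torch_entities.attention.
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)

batch, n_entities, entity_size, emb_size = 4, 6, 8, 64
entities = torch.rand((batch, n_entities, entity_size))
entities[:, 3:, :] = 0.0  # zero out the last three entity slots (padding)

embedding = EntityEmbedding(entity_size, n_entities, emb_size)  # no self size
rsa = ResidualSelfAttention(emb_size, n_entities)

masks = get_zero_entities_mask([entities])  # 1 where an entity is all zeros
qkv = embedding(None, entities)  # no x_self, as in the forward methods above
attended = rsa(qkv, masks)
print(attended.shape)  # expected (batch, emb_size), as in the tests below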
Example 4
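A training test: an EntityEmbedding, ResidualSelfAttention, and LinearEncoder stack is trained to predict, out of a variable number of random scalars tagged with positional one-hots, the index of the minimum; the mean of the last 20 losses must fall below 0.1.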
import numpy as np
import torch

# Assumed ml-agents import paths; newer releases move these under
# mlagents.trainers.torch_entities.
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)
from mlagents.trainers.torch.layers import LinearEncoder
from mlagents.trainers.torch.utils import ModelUtils


def test_predict_minimum_training():
    # Of 5 numbers, predict the index of the minimum
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embedding = EntityEmbedding(size, n_k, embedding_size)  # no self
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embedding.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )

    batch_size = 200
    # torch.arange replaces the deprecated torch.range (same 0..n_k-1 values)
    onehots = ModelUtils.actions_to_onehot(
        torch.arange(0, n_k).unsqueeze(1), [n_k])[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # The target is the index of the minimum value
            argmin = torch.argmin(inp, dim=1)
            argmin = argmin.squeeze()
            argmin = argmin.detach()
        sliced_oh = onehots[:, :num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)

        embeddings = entity_embedding(inp, inp)
        masks = get_zero_entities_mask([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        print(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1
Example 5
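A training test with a self embedding: the model attends over n_k random keys and is trained to regress the key closest, in Euclidean distance, to a random query point; the final regression error must fall below 0.02.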
import numpy as np
import torch

# Assumed ml-agents import paths; newer releases move these under
# mlagents.trainers.torch_entities.
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)
from mlagents.trainers.torch.layers import linear_layer


def test_predict_closest_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # The target is the key closest to the query in Euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02