Example #1
    def forward(
        self,
        encoded_paths: torch.Tensor,
        contexts_per_label: List[int],
    ) -> torch.Tensor:
        """Classify given paths

        :param encoded_paths: [n paths; classifier size]
        :param contexts_per_label: [n1, n2, ..., nk] sum = n paths
        :return:
        """
        # [batch size; max context size; classifier input size], [batch size; max context size]
        batched_context, attention_mask = cut_encoded_contexts(
            encoded_paths, contexts_per_label, self._negative_value)

        # [batch size; max context size; 1]
        attn_weights = self.attention(batched_context, attention_mask)

        # [batch size; classifier input size]
        context = torch.bmm(attn_weights.transpose(1, 2),
                            batched_context).squeeze(1)

        # [batch size; hidden size]
        hidden = self.hidden_layers(context)

        # [batch size; num classes]
        output = self.classification_layer(hidden)
        return output
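
This forward pass (and the decoder in Example #2) relies on cut_encoded_contexts to turn the flat [n paths; units] tensor into a padded batch plus an attention mask. The helper itself is not shown here; below is a minimal sketch of the behavior that the test in Example #3 expects (zero padding for the contexts, 0 for real positions and mask_value for padded positions in the mask), not necessarily the original implementation.

import torch
from typing import List, Tuple


def cut_encoded_contexts(
    encoded_contexts: torch.Tensor,
    contexts_per_label: List[int],
    mask_value: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Split the flat [n paths; units] tensor into per-sample chunks,
    # zero-pad each chunk to the longest sample, and build an additive
    # attention mask: 0 for real contexts, mask_value for padding.
    batch_size = len(contexts_per_label)
    max_context_len = max(contexts_per_label)
    units = encoded_contexts.shape[-1]

    batched_context = encoded_contexts.new_zeros(
        (batch_size, max_context_len, units))
    attention_mask = encoded_contexts.new_zeros(
        (batch_size, max_context_len))

    for i, sample in enumerate(encoded_contexts.split(contexts_per_label)):
        n = contexts_per_label[i]
        batched_context[i, :n] = sample
        attention_mask[i, n:] = mask_value

    return batched_context, attention_mask
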
Example #2
    def forward(
        self,
        encoded_paths: torch.Tensor,
        contexts_per_label: List[int],
        output_length: int,
        target_sequence: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode given paths into sequence

        :param encoded_paths: [n paths; decoder size]
        :param contexts_per_label: [n1, n2, ..., nk] sum = n paths
        :param output_length: length of output sequence
        :param target_sequence: [sequence length; batch size]
        :return:
        """
        batch_size = len(contexts_per_label)
        # [batch size; max context size; decoder size], [batch size; max context size]
        batched_context, attention_mask = cut_encoded_contexts(
            encoded_paths, contexts_per_label, self._negative_value)

        # [n layers; batch size; decoder size]
        initial_state = (torch.cat([
            ctx_batch.mean(0).unsqueeze(0)
            for ctx_batch in encoded_paths.split(contexts_per_label)
        ]).unsqueeze(0).repeat(self.num_decoder_layers, 1, 1))
        h_prev, c_prev = initial_state, initial_state

        # [output length; batch size; vocab size]
        output = encoded_paths.new_zeros(
            (output_length, batch_size, self.out_size))
        # [batch size]
        current_input = encoded_paths.new_full((batch_size, ),
                                               self.sos_token,
                                               dtype=torch.long)
        for step in range(output_length):
            current_output, (h_prev, c_prev) = self.decoder_step(
                current_input, h_prev, c_prev, batched_context, attention_mask)
            output[step] = current_output
            if target_sequence is not None and torch.rand(
                    1) < self.teacher_forcing:
                current_input = target_sequence[step]
            else:
                current_input = output[step].argmax(dim=-1)

        return output
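
The decoder initializes its LSTM state with the mean of each sample's encoded paths, repeated for every decoder layer, and applies teacher forcing with probability self.teacher_forcing. A small standalone sketch of the state construction (the shapes follow the comments above; the concrete sizes are made up for illustration):

import torch

encoded_paths = torch.randn(6, 4)   # 6 paths, decoder size 4
contexts_per_label = [2, 4]         # batch size 2, sums to 6 paths
num_decoder_layers = 2

# Mean-pool each sample's paths, then repeat the result for every decoder layer.
initial_state = torch.cat([
    ctx_batch.mean(0).unsqueeze(0)
    for ctx_batch in encoded_paths.split(contexts_per_label)
]).unsqueeze(0).repeat(num_decoder_layers, 1, 1)

print(initial_state.shape)  # torch.Size([2, 2, 4]) -> [n layers; batch size; decoder size]
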
Example #3
    def test_cut_encoded_contexts(self):
        units = 10
        mask_value = -1
        batch_size = 5
        contexts_per_label = list(range(1, batch_size + 1))
        max_context_len = max(contexts_per_label)

        encoded_contexts = torch.cat([
            torch.full((i, units), i, dtype=torch.float)
            for i in contexts_per_label
        ])

        def create_true_batch(fill_value: int, counts: int,
                              size: int) -> torch.Tensor:
            return torch.cat(
                [
                    torch.full(
                        (1, counts, units), fill_value, dtype=torch.float),
                    torch.zeros((1, size - counts, units))
                ],
                dim=1,
            )

        def create_batch_mask(counts: int, size: int) -> torch.Tensor:
            return torch.cat([
                torch.zeros(1, counts),
                torch.full((1, size - counts), mask_value, dtype=torch.float)
            ],
                             dim=1)

        true_batched_context = torch.cat([
            create_true_batch(i, i, max_context_len)
            for i in contexts_per_label
        ])
        true_attention_mask = torch.cat([
            create_batch_mask(i, max_context_len) for i in contexts_per_label
        ])

        batched_context, attention_mask = cut_encoded_contexts(
            encoded_contexts, contexts_per_label, mask_value)

        torch.testing.assert_allclose(batched_context, true_batched_context)
        torch.testing.assert_allclose(attention_mask, true_attention_mask)
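
The test pins down the mask convention: 0 for real contexts and mask_value for padded positions. A minimal sketch of how such an additive mask is typically consumed by an attention module like self.attention in Example #1 (the actual module is not part of these examples; the class name below is hypothetical and mask_value is assumed to be a large negative number such as -1e9):

import torch
from torch import nn


class MaskedAttention(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, units: int):
        super().__init__()
        self.score = nn.Linear(units, 1)

    def forward(self, batched_context: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        # [batch size; max context size; 1]
        scores = self.score(batched_context)
        # Adding the mask pushes padded positions towards -inf,
        # so they receive ~0 weight after the softmax.
        scores = scores + attention_mask.unsqueeze(-1)
        return torch.softmax(scores, dim=1)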