Example 1
    def forward(
        self,
        contextualized_input_batch: Tensor,
        stacks: Tensor,
        buffers: Tensor,
        stack_lengths: Tensor,
        buffer_lengths: Tensor,
        padding: Optional[Tensor] = None,
        sentence_features: Optional[torch.Tensor] = None,
        sentence_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> Tensor:
        stack_batch = lookup_tensors_for_indices(stacks, contextualized_input_batch)
        buffer_batch = lookup_tensors_for_indices(buffers, contextualized_input_batch)

        stack_batch = self.positional_encoder(stack_batch)
        buffer_batch = self.positional_encoder(buffer_batch)

        # Compute a representation of the stack / buffer as a weighted average based
        # on the attention weights.
        stack_batch_attention, _, stack_attention_energies = self.stack_attention(
            keys=stack_batch, sequence_lengths=stack_lengths
        )

        buffer_batch_attention, _, buffer_attention_energies = self.buffer_attention(
            keys=buffer_batch, sequence_lengths=buffer_lengths
        )

        if self.reporter:
            self.reporter.log(
                "buffer",
                buffers,
                buffer_lengths,
                buffer_attention_energies,
                sentence_features,
                sentence_ids,
            )
            self.reporter.log(
                "stack",
                stacks,
                stack_lengths,
                stack_attention_energies,
                sentence_features,
                sentence_ids,
            )

        return torch.cat((stack_batch_attention, buffer_batch_attention), dim=1)
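
All three variants use the lookup_tensors_for_indices helper to gather, for every stack or buffer position, the corresponding contextualized vector. The original implementation of that helper is not shown here; a minimal sketch, assuming indices of shape (batch, positions) and a contextualized batch of shape (batch, seq_len, hidden), could be based on torch.gather:

    import torch
    from torch import Tensor


    def lookup_tensors_for_indices(indices: Tensor, sequences: Tensor) -> Tensor:
        """Gather the contextualized vectors addressed by `indices`.

        indices:   (batch, positions) int64 tensor of token positions
        sequences: (batch, seq_len, hidden) contextualized input batch
        returns:   (batch, positions, hidden)
        """
        hidden_size = sequences.size(-1)
        # Repeat each index across the hidden dimension so gather picks whole vectors.
        expanded = indices.unsqueeze(-1).expand(-1, -1, hidden_size)
        return torch.gather(sequences, dim=1, index=expanded)
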
Example 2
    def forward(
        self,
        contextualized_input_batch: Tensor,
        stacks: Tensor,
        buffers: Tensor,
        stack_lengths: Tensor,
        buffer_lengths: Tensor,
        finished_tokens: Tensor,
        finished_tokens_lengths: Tensor,
        sentence_lengths: Tensor,
        padding: Optional[Tensor] = None,
        sentence_features: Optional[torch.Tensor] = None,
        sentence_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> Tensor:
        # Look up the whole unpadded stack, buffer, and finished-token sequences
        stack_keys = lookup_tensors_for_indices(stacks, contextualized_input_batch)
        buffer_keys = lookup_tensors_for_indices(buffers, contextualized_input_batch)
        finished_tokens_keys = lookup_tensors_for_indices(
            finished_tokens, contextualized_input_batch
        )

        # Add a padding vector so that even empty sequences have at least one item
        stack_keys = self._fix_empty_sequence(stack_keys)
        buffer_keys = self._fix_empty_sequence(buffer_keys)
        finished_tokens_keys = self._fix_empty_sequence(finished_tokens_keys)

        # Add positional encoding
        stack_keys = self.positional_encoder(stack_keys)
        buffer_keys = self.positional_encoder(buffer_keys)
        finished_tokens_keys = self.positional_encoder(finished_tokens_keys)
        sentence_keys = self.positional_encoder(contextualized_input_batch)

        # Run universal attention over the finished tokens
        tokens_attention, _, tokens_attention_energies = self.finished_tokens_attention(
            keys=finished_tokens_keys, sequence_lengths=finished_tokens_lengths
        )

        sentence_attention, _, sentence_attention_energies = self.sentence_attention(
            queries=tokens_attention, keys=sentence_keys, sequence_lengths=sentence_lengths
        )

        # Compute a representation of the stack / buffer as a weighted average based
        # on the attention weights.
        stack_batch_attention, _, stack_attention_energies = self.stack_attention(
            queries=sentence_attention, keys=stack_keys, sequence_lengths=stack_lengths
        )

        buffer_batch_attention, _, buffer_attention_energies = self.buffer_attention(
            queries=sentence_attention, keys=buffer_keys, sequence_lengths=buffer_lengths
        )

        if self.reporter:
            sentence_tokens = torch.arange(
                contextualized_input_batch.size(1), device=self.device
            ).expand(contextualized_input_batch.size(0), contextualized_input_batch.size(1))
            self.reporter.log(
                "finished_tokens",
                finished_tokens,
                finished_tokens_lengths,
                tokens_attention_energies,
                sentence_features,
                sentence_ids,
            )
            self.reporter.log(
                "sentence",
                sentence_tokens,
                sentence_lengths,
                sentence_attention_energies,
                sentence_features,
                sentence_ids,
            )
            self.reporter.log(
                "stack",
                stacks,
                stack_lengths,
                stack_attention_energies,
                sentence_features,
                sentence_ids,
            )
            self.reporter.log(
                "buffer",
                buffers,
                buffer_lengths,
                buffer_attention_energies,
                sentence_features,
                sentence_ids,
            )

        return torch.cat((stack_batch_attention, buffer_batch_attention), dim=1)
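
Example 2 (and Example 3 below) call self._fix_empty_sequence so that the attention modules never see a completely empty stack or buffer. How the original method pads is not shown; a plausible sketch, assuming a single zero vector is appended along the sequence dimension, is:

    import torch
    from torch import Tensor


    def _fix_empty_sequence(self, batch: Tensor) -> Tensor:
        """Append one padding vector so every sequence has at least one key.

        batch: (batch_size, seq_len, hidden) -> (batch_size, seq_len + 1, hidden)
        """
        batch_size, _, hidden_size = batch.size()
        padding = torch.zeros(
            batch_size, 1, hidden_size, device=batch.device, dtype=batch.dtype
        )
        return torch.cat((batch, padding), dim=1)

Whether the padding vector is appended or prepended, and how zero-length sequences are then handled, depends on the original attention implementation.
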
Example 3
    def forward(
        self,
        contextualized_input_batch: Tensor,
        stacks: Tensor,
        buffers: Tensor,
        stack_lengths: Tensor,
        buffer_lengths: Tensor,
        padding: Optional[Tensor] = None,
        sentence_features: Optional[torch.Tensor] = None,
        sentence_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> Tensor:
        # Look up the whole unpadded stack and buffer sequences
        stack_keys = lookup_tensors_for_indices(stacks, contextualized_input_batch)
        buffer_keys = lookup_tensors_for_indices(buffers, contextualized_input_batch)

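        # Add a padding vector so that even empty sequences have at least one item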
        stack_keys = self._fix_empty_sequence(stack_keys)
        buffer_keys = self._fix_empty_sequence(buffer_keys)

        # Add positional encoding
        stack_keys = self.positional_encoder(stack_keys)
        buffer_keys = self.positional_encoder(buffer_keys)

        # Use the first buffer entry as the query over the stack and, symmetrically,
        # the first stack entry as the query over the buffer
        stack_queries = buffer_keys.index_select(
            dim=1, index=torch.zeros(1, dtype=torch.int64, device=self.device)
        ).squeeze(1)

        buffer_queries = stack_keys.index_select(
            dim=1, index=torch.zeros(1, dtype=torch.int64, device=self.device)
        ).squeeze(1)

        # Compute a representation of the stack / buffer as a weighted average based
        # on the attention weights.
        stack_batch_attention, _, stack_attention_energies = self.stack_attention(
            queries=stack_queries, keys=stack_keys, sequence_lengths=stack_lengths
        )

        buffer_batch_attention, _, buffer_attention_energies = self.buffer_attention(
            queries=buffer_queries, keys=buffer_keys, sequence_lengths=buffer_lengths
        )

        if self.reporter:
            self.reporter.log(
                "stack",
                stacks,
                stack_lengths,
                stack_attention_energies,
                sentence_features,
                sentence_ids,
            )
            self.reporter.log(
                "buffer",
                buffers,
                buffer_lengths,
                buffer_attention_energies,
                sentence_features,
                sentence_ids,
            )

        return torch.cat((stack_batch_attention, buffer_batch_attention), dim=1)
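
In all three examples the attention modules (self.stack_attention, self.buffer_attention, self.finished_tokens_attention, self.sentence_attention) share one interface: they take keys, sequence_lengths, and optionally queries, and return a triple whose first and third elements are the weighted average and the raw attention energies (the ignored second element is presumably the normalized weights). A rough sketch of a compatible module, assuming masked dot-product scoring and a learned fallback query for the query-less case:

    from typing import Optional, Tuple

    import torch
    from torch import Tensor, nn


    class MaskedSoftAttention(nn.Module):
        """Masked dot-product attention returning (context, weights, energies)."""

        def __init__(self, hidden_size: int) -> None:
            super().__init__()
            # Learned query used when no explicit query is given ("universal" attention).
            self.default_query = nn.Parameter(torch.randn(hidden_size))

        def forward(
            self,
            keys: Tensor,
            sequence_lengths: Tensor,
            queries: Optional[Tensor] = None,
        ) -> Tuple[Tensor, Tensor, Tensor]:
            batch_size, seq_len, _ = keys.size()
            if queries is None:
                queries = self.default_query.expand(batch_size, -1)

            # One raw energy per key position: (batch, seq_len).
            energies = torch.bmm(keys, queries.unsqueeze(-1)).squeeze(-1)

            # Mask positions beyond each true length; clamp to at least one valid
            # position so the softmax never sees an all-masked row
            # (cf. _fix_empty_sequence above).
            positions = torch.arange(seq_len, device=keys.device).expand(batch_size, seq_len)
            mask = positions >= sequence_lengths.clamp(min=1).unsqueeze(1)
            masked_energies = energies.masked_fill(mask, float("-inf"))

            weights = torch.softmax(masked_energies, dim=1)
            context = torch.bmm(weights.unsqueeze(1), keys).squeeze(1)
            return context, weights, energies

The class name, the dot-product scoring, and the fallback query are all assumptions; the original modules may use additive attention or a different masking scheme, but the call signature and return order match how the forward methods above use them.
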