Code example #1
File: coref.py Project: k1c/allennlp-models-1
    def _coarse_to_fine_pruning(
        self,
        top_span_embeddings: torch.FloatTensor,
        top_span_mention_scores: torch.FloatTensor,
        top_span_mask: torch.BoolTensor,
        max_antecedents: int,
    ) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor,
               torch.LongTensor]:
        """
        Generates antecedents for each span and prunes down to `max_antecedents`. This method
        prunes antecedents using a fast bilinear interaction score between a span and a candidate
        antecedent, and the highest-scoring antecedents are kept.

        # Parameters

        top_span_embeddings: torch.FloatTensor, required.
            The embeddings of the top spans.
            (batch_size, num_spans_to_keep, embedding_size).
        top_span_mention_scores: torch.FloatTensor, required.
            The mention scores of the top spans.
            (batch_size, num_spans_to_keep).
        top_span_mask: torch.BoolTensor, required.
            The mask for the top spans.
            (batch_size, num_spans_to_keep).
        max_antecedents: int, required.
            The maximum number of antecedents to keep for each span.

        # Returns

        top_partial_coreference_scores: torch.FloatTensor
            The partial antecedent scores for each span-antecedent pair. Computed by summing
            the span mention scores of the span and the antecedent as well as a bilinear
            interaction term. This score is partial because compared to the full coreference scores,
            it lacks the interaction term
            w * FFNN([g_i, g_j, g_i * g_j, features]).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_mask: torch.BoolTensor
            The mask representing whether each antecedent span is valid. Required since
            different spans have different numbers of valid antecedents. For example, the first
            span in the document should have no valid antecedents.
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_offsets: torch.LongTensor
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e. not the word distance between the spans).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_indices: torch.LongTensor
            The indices of every antecedent to consider with respect to the top k spans.
            (batch_size, num_spans_to_keep, max_antecedents)
        """
        batch_size, num_spans_to_keep = top_span_embeddings.size()[:2]
        device = util.get_device_of(top_span_embeddings)

        # Shape: (1, num_spans_to_keep, num_spans_to_keep)
        _, _, valid_antecedent_mask = self._generate_valid_antecedents(
            num_spans_to_keep, num_spans_to_keep, device)

        mention_one_score = top_span_mention_scores.unsqueeze(1)
        mention_two_score = top_span_mention_scores.unsqueeze(2)
        bilinear_weights = self._coarse2fine_scorer(
            top_span_embeddings).transpose(1, 2)
        bilinear_score = torch.matmul(top_span_embeddings, bilinear_weights)
        # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
        partial_antecedent_scores = mention_one_score + mention_two_score + bilinear_score

        # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
        span_pair_mask = top_span_mask.unsqueeze(-1) & valid_antecedent_mask

        # Shape:
        # (batch_size, num_spans_to_keep, max_antecedents) * 3
        (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_indices,
        ) = util.masked_topk(partial_antecedent_scores, span_pair_mask,
                             max_antecedents)

        top_span_range = util.get_range_vector(num_spans_to_keep, device)
        # Shape: (num_spans_to_keep, num_spans_to_keep); broadcast op
        valid_antecedent_offsets = top_span_range.unsqueeze(
            -1) - top_span_range.unsqueeze(0)

        # TODO: we need to make `batched_index_select` more general to make this less awkward.
        top_antecedent_offsets = util.batched_index_select(
            valid_antecedent_offsets.unsqueeze(0).expand(
                batch_size, num_spans_to_keep,
                num_spans_to_keep).reshape(batch_size * num_spans_to_keep,
                                           num_spans_to_keep, 1),
            top_antecedent_indices.view(-1, max_antecedents),
        ).reshape(batch_size, num_spans_to_keep, max_antecedents)

        return (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        )
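
The pruning above sums two mention scores with a bilinear term and keeps the `max_antecedents` highest-scoring candidates under a validity mask. Below is a minimal standalone sketch of that idea in plain PyTorch, using random toy tensors and a strict lower-triangular mask in place of AllenNLP's `masked_topk` and `_generate_valid_antecedents` helpers; it is an illustration, not the library implementation.

import torch

# Toy sketch of coarse-to-fine antecedent pruning (random inputs, hypothetical sizes).
batch_size, num_spans, emb, max_antecedents = 2, 5, 8, 3
span_emb = torch.randn(batch_size, num_spans, emb)
mention_scores = torch.randn(batch_size, num_spans)

# Bilinear interaction s_ij = g_i^T W g_j with a random weight matrix W.
W = torch.randn(emb, emb)
bilinear = span_emb @ W @ span_emb.transpose(1, 2)            # (batch, num_spans, num_spans)
partial = mention_scores.unsqueeze(1) + mention_scores.unsqueeze(2) + bilinear

# Only spans that appear before span i are valid antecedents (strict lower triangle).
valid = torch.tril(torch.ones(num_spans, num_spans, dtype=torch.bool), diagonal=-1)
partial = partial.masked_fill(~valid, float("-inf"))

# Keep the top-k candidates per span; spans with no valid antecedent keep -inf scores.
top_scores, top_indices = partial.topk(max_antecedents, dim=-1)
print(top_scores.shape, top_indices.shape)                    # torch.Size([2, 5, 3]) twice
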
Code example #2
File: bimpm_matching.py Project: PYART0/PyART-demo
    def forward(
        self,
        context_1: torch.Tensor,
        mask_1: torch.BoolTensor,
        context_2: torch.Tensor,
        mask_2: torch.BoolTensor,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Given the forward (or backward) representations of sentence1 and sentence2, apply four bilateral
        matching functions between them in one direction.

        # Parameters

        context_1 : `torch.Tensor`
            Tensor of shape (batch_size, seq_len1, hidden_dim) representing the encoding of the first sentence.
        mask_1 : `torch.BoolTensor`
            Boolean Tensor of shape (batch_size, seq_len1), indicating which
            positions in the first sentence are padding (0) and which are not (1).
        context_2 : `torch.Tensor`
            Tensor of shape (batch_size, seq_len2, hidden_dim) representing the encoding of the second sentence.
        mask_2 : `torch.BoolTensor`
            Boolean Tensor of shape (batch_size, seq_len2), indicating which
            positions in the second sentence are padding (0) and which are not (1).

        # Returns

        `Tuple[List[torch.Tensor], List[torch.Tensor]]` :
            A tuple of matching vectors for the two sentences, each of which is a list of
            matching vectors of shape (batch, seq_len, num_perspectives or 1).
        """
        assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
        assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim

        # (batch,)
        len_1 = get_lengths_from_binary_sequence_mask(mask_1)
        len_2 = get_lengths_from_binary_sequence_mask(mask_2)

        # explicitly set masked weights to zero
        # (batch_size, seq_len*, hidden_dim)
        context_1 = context_1 * mask_1.unsqueeze(-1)
        context_2 = context_2 * mask_2.unsqueeze(-1)

        # array to keep the matching vectors for the two sentences
        matching_vector_1: List[torch.Tensor] = []
        matching_vector_2: List[torch.Tensor] = []

        # Step 0. unweighted cosine
        # First calculate the cosine similarities between each forward
        # (or backward) contextual embedding and every forward (or backward)
        # contextual embedding of the other sentence.

        # (batch, seq_len1, seq_len2)
        cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2),
                                         context_2.unsqueeze(-3),
                                         dim=3)

        # (batch, seq_len*, 1)
        cosine_max_1 = masked_max(cosine_sim,
                                  mask_2.unsqueeze(-2),
                                  dim=2,
                                  keepdim=True)
        cosine_mean_1 = masked_mean(cosine_sim,
                                    mask_2.unsqueeze(-2),
                                    dim=2,
                                    keepdim=True)
        cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1),
                                  mask_1.unsqueeze(-2),
                                  dim=2,
                                  keepdim=True)
        cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1),
                                    mask_1.unsqueeze(-2),
                                    dim=2,
                                    keepdim=True)

        matching_vector_1.extend([cosine_max_1, cosine_mean_1])
        matching_vector_2.extend([cosine_max_2, cosine_mean_2])

        # Step 1. Full-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with the last time step of the forward (or backward)
        # contextual embedding of the other sentence
        if self.with_full_match:

            # (batch, 1, hidden_dim)
            if self.is_forward:
                # (batch, 1, hidden_dim)
                last_position_1 = (len_1 - 1).clamp(min=0)
                last_position_1 = last_position_1.view(-1, 1, 1).expand(
                    -1, 1, self.hidden_dim)
                last_position_2 = (len_2 - 1).clamp(min=0)
                last_position_2 = last_position_2.view(-1, 1, 1).expand(
                    -1, 1, self.hidden_dim)

                context_1_last = context_1.gather(1, last_position_1)
                context_2_last = context_2.gather(1, last_position_2)
            else:
                context_1_last = context_1[:, 0:1, :]
                context_2_last = context_2[:, 0:1, :]

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_full = multi_perspective_match(
                context_1, context_2_last, self.full_match_weights)
            matching_vector_2_full = multi_perspective_match(
                context_2, context_1_last, self.full_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_full)
            matching_vector_2.extend(matching_vector_2_full)

        # Step 2. Maxpooling-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with every time step of the forward (or backward)
        # contextual embedding of the other sentence, and only the max value of each
        # dimension is retained.
        if self.with_maxpool_match:
            # (batch, seq_len1, seq_len2, num_perspectives)
            matching_vector_max = multi_perspective_match_pairwise(
                context_1, context_2, self.maxpool_match_weights)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_max = masked_max(
                matching_vector_max, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            matching_vector_1_mean = masked_mean(
                matching_vector_max, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            matching_vector_2_max = masked_max(
                matching_vector_max.permute(0, 2, 1, 3),
                mask_1.unsqueeze(-2).unsqueeze(-1),
                dim=2)
            matching_vector_2_mean = masked_mean(
                matching_vector_max.permute(0, 2, 1, 3),
                mask_1.unsqueeze(-2).unsqueeze(-1),
                dim=2)

            matching_vector_1.extend(
                [matching_vector_1_max, matching_vector_1_mean])
            matching_vector_2.extend(
                [matching_vector_2_max, matching_vector_2_mean])

        # Step 3. Attentive-Matching
        # Each forward (or backward) similarity is taken as the weight
        # of the forward (or backward) contextual embedding, and calculate an
        # attentive vector for the sentence by weighted summing all its
        # contextual embeddings.
        # Finally match each forward (or backward) contextual embedding
        # with its corresponding attentive vector.

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_2 = context_2.unsqueeze(-3) * cosine_sim.unsqueeze(-1)

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_1 = context_1.unsqueeze(-2) * cosine_sim.unsqueeze(-1)

        if self.with_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_mean_2 = masked_softmax(att_2.sum(dim=2), mask_1.unsqueeze(-1))
            att_mean_1 = masked_softmax(att_1.sum(dim=1), mask_2.unsqueeze(-1))

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_mean = multi_perspective_match(
                context_1, att_mean_2, self.attentive_match_weights)
            matching_vector_2_att_mean = multi_perspective_match(
                context_2, att_mean_1, self.attentive_match_weights_reversed)
            matching_vector_1.extend(matching_vector_1_att_mean)
            matching_vector_2.extend(matching_vector_2_att_mean)

        # Step 4. Max-Attentive-Matching
        # Pick the contextual embeddings with the highest cosine similarity as the attentive
        # vector, and match each forward (or backward) contextual embedding with its
        # corresponding attentive vector.
        if self.with_max_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_max_2 = masked_max(att_2,
                                   mask_2.unsqueeze(-2).unsqueeze(-1),
                                   dim=2)
            att_max_1 = masked_max(att_1.permute(0, 2, 1, 3),
                                   mask_1.unsqueeze(-2).unsqueeze(-1),
                                   dim=2)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_max = multi_perspective_match(
                context_1, att_max_2, self.max_attentive_match_weights)
            matching_vector_2_att_max = multi_perspective_match(
                context_2, att_max_1,
                self.max_attentive_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_att_max)
            matching_vector_2.extend(matching_vector_2_att_max)

        return matching_vector_1, matching_vector_2
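
Step 0 above computes a full pairwise cosine-similarity matrix between the two sentences and then pools over the other sentence with masked max/mean. A self-contained sketch of just that step, with the AllenNLP `masked_max`/`masked_mean` helpers replaced by explicit masking (toy shapes and random tensors, assumed only for illustration):

import torch
import torch.nn.functional as F

batch, len1, len2, dim = 2, 4, 6, 8
context_1 = torch.randn(batch, len1, dim)
context_2 = torch.randn(batch, len2, dim)
mask_2 = torch.ones(batch, len2, dtype=torch.bool)
mask_2[:, -2:] = False                                  # pretend the last two positions are padding

# (batch, len1, len2): cosine similarity between every pair of positions.
cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2), context_2.unsqueeze(-3), dim=3)

# Masked max / mean over sentence 2 (explicit stand-ins for masked_max / masked_mean).
neg_inf = torch.finfo(cosine_sim.dtype).min
cos_max_1 = cosine_sim.masked_fill(~mask_2.unsqueeze(1), neg_inf).max(dim=2, keepdim=True).values
cos_sum_1 = (cosine_sim * mask_2.unsqueeze(1)).sum(dim=2, keepdim=True)
cos_mean_1 = cos_sum_1 / mask_2.sum(dim=1, keepdim=True).unsqueeze(-1).clamp(min=1)
print(cos_max_1.shape, cos_mean_1.shape)                # (2, 4, 1) each
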
Code example #3
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        mrc_mask: torch.BoolTensor,
        parent_idxs: torch.LongTensor = None,
        parent_tags: torch.LongTensor = None,
        # is_subtree: torch.BoolTensor = None
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            mrc_mask: [batch_size, num_words]
            parent_idxs: [batch_size]
            parent_tags: [batch_size]
            # is_subtree: [batch_size]
        Returns:
            # is_subtree_probs: [batch_size]
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words, num_tags]
            # subtree_loss(if is_subtree is not None)
            arc_loss (if parent_idxs is not None)
            tag_loss (if parent_idxs and parent_tags are not None)
        """

        cls_embedding, embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)
        cls_embedding = self._dropout(cls_embedding)

        # [bsz]
        # subtree_scores = self.is_subtree_feedforward(cls_embedding).squeeze(-1)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)

        # mask out impossible positions
        minus_inf = -1e8
        mrc_mask = torch.logical_and(mrc_mask, word_mask)
        parent_scores = parent_scores + (~mrc_mask).float() * minus_inf

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

        # output = (torch.sigmoid(subtree_scores), parent_probs, parent_tag_probs)  # todo check if log in dp evaluation
        output = (parent_probs, parent_tag_probs
                  )  # todo check if log in dp evaluation

        # add losses
        # if is_subtree is not None:
        #     subtree_loss = F.binary_cross_entropy_with_logits(subtree_scores, is_subtree.float())
        #     output = output + (subtree_loss, )
        # else:
        is_subtree = torch.ones_like(parent_tags).bool()

        if parent_idxs is not None:
            sample_mask = is_subtree.float()
            # [bsz]
            batch_range_vector = get_range_vector(batch_size,
                                                  get_device_of(encoded_text))
            # [bsz, seq_len]
            parent_logits = F.log_softmax(parent_scores, dim=-1)
            parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
            parent_arc_nll = (parent_arc_nll *
                              sample_mask).sum() / (sample_mask.sum() + 1e-8)
            output = output + (parent_arc_nll, )

            if parent_tags is not None:
                parent_tag_nll = F.cross_entropy(
                    parent_tag_scores[batch_range_vector, parent_idxs],
                    parent_tags,
                    reduction="none")
                parent_tag_nll = (parent_tag_nll * sample_mask).sum() / (
                    sample_mask.sum() + 1e-8)
                output = output + (parent_tag_nll, )

        return output
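
The parent-arc loss above is a negative log-likelihood obtained by indexing a log-softmax over positions with a batch range vector. A toy sketch of that indexing pattern (made-up tensors; `get_range_vector` replaced by `torch.arange`):

import torch
import torch.nn.functional as F

batch_size, seq_len = 3, 5
parent_scores = torch.randn(batch_size, seq_len)
mrc_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
parent_idxs = torch.tensor([1, 0, 4])                        # gold parent position per example

parent_scores = parent_scores.masked_fill(~mrc_mask, -1e8)   # mask impossible positions
log_probs = F.log_softmax(parent_scores, dim=-1)             # (batch, seq_len)
batch_range = torch.arange(batch_size)
arc_nll = -log_probs[batch_range, parent_idxs].mean()        # NLL of the gold parents
print(float(arc_nll))
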
Code example #4
    def sort_and_run_forward(
            self,
            module,
            inputs: torch.Tensor,
            mask: torch.BoolTensor,
            hidden_state = None,
    ):
        """
        This function exists because Pytorch RNNs require that their inputs be sorted
        before being passed as input. As all of our Seq2xxxEncoders use this functionality,
        it is provided in a base class. This method can be called on any module which
        takes as input a `PackedSequence` and some `hidden_state`, which can either be a
        tuple of tensors or a tensor.
        As all of our Seq2xxxEncoders have different return types, we return `sorted`
        outputs from the module, which is called directly. Additionally, we return the
        indices into the batch dimension required to restore the tensor to its correct,
        unsorted order and the number of valid batch elements (i.e. the number of elements
        in the batch which are not completely masked). This un-sorting and re-padding
        of the module outputs is left to the subclasses because their outputs have different
        types and handling them smoothly here is difficult.
        # Parameters
        module : `Callable[RnnInputs, RnnOutputs]`
            A function to run on the inputs, where
            `RnnInputs: [PackedSequence, Optional[RnnState]]` and
            `RnnOutputs: Tuple[Union[PackedSequence, torch.Tensor], RnnState]`.
            In most cases, this is a `torch.nn.Module`.
        inputs : `torch.Tensor`, required.
            A tensor of shape `(batch_size, sequence_length, embedding_size)` representing
            the inputs to the Encoder.
        mask : `torch.BoolTensor`, required.
            A tensor of shape `(batch_size, sequence_length)`, representing masked and
            non-masked elements of the sequence for each element in the batch.
        hidden_state : `Optional[RnnState]`, (default = `None`).
            A single tensor of shape (num_layers, batch_size, hidden_size) representing the
            state of an RNN, or a tuple of tensors of shapes
            (num_layers, batch_size, hidden_size) and (num_layers, batch_size, memory_size),
            representing the hidden state and memory state of an LSTM-like RNN.
        # Returns
        module_output : `Union[torch.Tensor, PackedSequence]`.
            A Tensor or PackedSequence representing the output of the Pytorch Module.
            The batch size dimension will be equal to `num_valid`, as sequences of zero
            length are clipped off before the module is called, as Pytorch cannot handle
            zero length sequences.
        final_states : `Optional[RnnState]`
            A Tensor representing the hidden state of the Pytorch Module. This can either
            be a single tensor of shape (num_layers, num_valid, hidden_size), for instance in
            the case of a GRU, or a tuple of tensors, such as those required for an LSTM.
        restoration_indices : `torch.LongTensor`
            A tensor of shape `(batch_size,)`, describing the re-indexing required to transform
            the outputs back to their original batch order.
        """
        # In some circumstances you may have sequences of zero length. `pack_padded_sequence`
        # requires all sequence lengths to be > 0, so remove sequences of zero length before
        # calling self._module, then fill with zeros.

        # First count how many sequences are empty.
        batch_size = mask.size(0)
        num_valid = torch.sum(mask[:, 0]).int().item()

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        (
            sorted_inputs,
            sorted_sequence_lengths,
            restoration_indices,
            sorting_indices,
        ) = sort_batch_by_length(inputs, sequence_lengths)

        # Now create a PackedSequence with only the non-empty, sorted sequences.
        packed_sequence_input = pack_padded_sequence(
            sorted_inputs[:num_valid, :, :],
            sorted_sequence_lengths[:num_valid].data.tolist(),
            batch_first=True,
        )
        # Prepare the initial states.
        if not self.stateful:
            if hidden_state is None:
                initial_states = hidden_state
            elif isinstance(hidden_state, tuple):
                initial_states = [
                    state.index_select(1, sorting_indices)[:, :num_valid, :].contiguous()
                    for state in hidden_state
                ]
            else:
                initial_states = hidden_state.index_select(1, sorting_indices)[
                                 :, :num_valid, :
                                 ].contiguous()

        else:
            initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices)

        # Actually call the module on the sorted PackedSequence.
        module_output, final_states = module(packed_sequence_input, initial_states)

        return module_output, final_states, restoration_indices
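
A standalone sketch of the sort-pack-run pattern described in the docstring, using a plain `nn.LSTM` and `torch.sort` in place of AllenNLP's `sort_batch_by_length` (toy sizes; not the library code):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch, max_len, dim, hidden = 4, 7, 5, 6
inputs = torch.randn(batch, max_len, dim)
lengths = torch.tensor([7, 3, 5, 0])                    # one completely empty sequence
mask = torch.zeros(batch, max_len, dtype=torch.bool)
for i, length in enumerate(lengths.tolist()):
    mask[i, :length] = True

num_valid = int(mask[:, 0].sum())                       # sequences with length > 0
sorted_lengths, sorting_indices = lengths.sort(descending=True)
restoration_indices = sorting_indices.argsort()         # would un-sort the outputs afterwards
sorted_inputs = inputs.index_select(0, sorting_indices)

packed = pack_padded_sequence(sorted_inputs[:num_valid],
                              sorted_lengths[:num_valid].tolist(),
                              batch_first=True)
lstm = torch.nn.LSTM(dim, hidden, batch_first=True)
packed_output, final_states = lstm(packed)
output, _ = pad_packed_sequence(packed_output, batch_first=True)
print(output.shape)                                     # (num_valid, longest valid length, hidden)
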
Code example #5
    def forward(self, inputs: torch.Tensor,
                mask: torch.BoolTensor) -> torch.Tensor:
        """
        # Parameters

        inputs : `torch.Tensor`, required.
            A Tensor of shape `(batch_size, sequence_length, hidden_size)`.
        mask : `torch.BoolTensor`, required.
            A binary mask of shape `(batch_size, sequence_length)` representing the
            non-padded elements in each sequence in the batch.

        # Returns

        `torch.Tensor`
            A `torch.Tensor` of shape (num_layers, batch_size, sequence_length, hidden_size),
            where the num_layers dimension represents the LSTM output from that layer.
        """
        batch_size, total_sequence_length = mask.size()
        stacked_sequence_output, final_states, restoration_indices = self.sort_and_run_forward(
            self._lstm_forward, inputs, mask)

        num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
        # Add back invalid rows which were removed in the call to sort_and_run_forward.
        if num_valid < batch_size:
            zeros = stacked_sequence_output.new_zeros(num_layers,
                                                      batch_size - num_valid,
                                                      returned_timesteps,
                                                      encoder_dim)
            stacked_sequence_output = torch.cat(
                [stacked_sequence_output, zeros], 1)

            # The states also need to have invalid rows added back.
            new_states = []
            for state in final_states:
                state_dim = state.size(-1)
                zeros = state.new_zeros(num_layers, batch_size - num_valid,
                                        state_dim)
                new_states.append(torch.cat([state, zeros], 1))
            final_states = new_states

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - returned_timesteps
        if sequence_length_difference > 0:
            zeros = stacked_sequence_output.new_zeros(
                num_layers,
                batch_size,
                sequence_length_difference,
                stacked_sequence_output[0].size(-1),
            )
            stacked_sequence_output = torch.cat(
                [stacked_sequence_output, zeros], 2)

        self._update_states(final_states, restoration_indices)

        # Restore the original indices and return the sequence.
        # Has shape (num_layers, batch_size, sequence_length, hidden_size)
        return stacked_sequence_output.index_select(1, restoration_indices)
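
The two padding steps above (adding back fully masked rows, then missing timesteps) reduce to allocating zeros and concatenating along the right dimension, followed by an `index_select` to undo the sorting. A toy illustration with a hypothetical un-sorting permutation:

import torch

num_layers, batch_size, num_valid, timesteps, dim = 2, 4, 3, 5, 6
output = torch.randn(num_layers, num_valid, timesteps, dim)
restoration_indices = torch.tensor([2, 0, 3, 1])        # hypothetical permutation

zeros = output.new_zeros(num_layers, batch_size - num_valid, timesteps, dim)
padded = torch.cat([output, zeros], dim=1)              # (num_layers, batch_size, timesteps, dim)
restored = padded.index_select(1, restoration_indices)  # back to the original batch order
print(restored.shape)                                   # torch.Size([2, 4, 5, 6])
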
Code example #6
File: ud.py Project: lgessler/embur
    def _parse(
        self,
        encoded_text: torch.Tensor,
        mask: torch.BoolTensor,
        deprel_labels: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if deprel_labels is not None:
            deprel_labels = torch.cat([deprel_labels.new_zeros(batch_size, 1), deprel_labels], 1)
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation)

        minus_inf = -1e8
        minus_mask = ~mask * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_deprel_labels = self._greedy_decode(
                head_tag_representation, child_tag_representation, attended_arcs, mask
            )
        else:
            predicted_heads, predicted_deprel_labels = self._mst_decode(
                head_tag_representation, child_tag_representation, attended_arcs, mask
            )
        if head_indices is not None and deprel_labels is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                deprel_labels=deprel_labels,
                mask=mask,
            )
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                deprel_labels=predicted_deprel_labels.long(),
                mask=mask,
            )

        return predicted_heads, predicted_deprel_labels, mask, arc_nll, tag_nll
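
Two details of `_parse` worth isolating are the prepended head (ROOT) sentinel and the additive `-1e8` mask on the arc score matrix. A small sketch of both with random tensors (shapes assumed for illustration):

import torch

batch, seq_len, dim = 2, 4, 8
encoded_text = torch.randn(batch, seq_len, dim)
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])
head_sentinel = torch.randn(1, 1, dim).expand(batch, 1, dim)

encoded_text = torch.cat([head_sentinel, encoded_text], dim=1)   # (batch, seq_len + 1, dim)
mask = torch.cat([mask.new_ones(batch, 1), mask], dim=1)         # the sentinel is never padding

attended_arcs = torch.randn(batch, seq_len + 1, seq_len + 1)
minus_mask = ~mask * -1e8                                        # 0 for real tokens, -1e8 for padding
attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
print(attended_arcs[0, :, -1])                                   # the padded column sits near -1e8 or lower
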
Code example #7
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        span_idx: torch.LongTensor,
        span_tag: torch.LongTensor,
        child_arcs: torch.LongTensor,
        child_tags: torch.LongTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        mrc_mask: torch.BoolTensor,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            span_idx: [batch_size, 2]
            span_tag: [batch_size, 1]
            child_arcs: [batch_size, num_words]
            child_tags: [batch_size, num_words]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            mrc_mask: [batch_size, num_words]
        Returns:
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words, num_tags]
            child_probs: [batch_size, num_words]
            child_tag_probs: [batch_size, num_words, num_tags]
            parent_arc_nll: [1]
            parent_tag_nll: [1]
            child_arc_loss: [1]
            child_tag_loss: [1]
        """

        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)

        # [bsz, seq_len, tag_classes]
        child_tag_scores = self.child_tag_feedforward(encoded_text)
        # [bsz, seq_len]
        child_scores = self.child_feedforward(encoded_text).squeeze(-1)

        # todo support cases that span_idx and span_tag are None
        # [bsz]
        batch_range_vector = get_range_vector(batch_size,
                                              get_device_of(encoded_text))
        # [bsz]
        gold_positions = span_idx[:, 0]

        # compute parent arc loss
        minus_inf = -1e8
        mrc_mask = torch.logical_and(mrc_mask, word_mask)
        parent_scores = parent_scores + (~mrc_mask).float() * minus_inf
        child_scores = child_scores + (~mrc_mask).float() * minus_inf

        # [bsz, seq_len]
        parent_logits = F.log_softmax(parent_scores, dim=-1)
        parent_arc_nll = -parent_logits[batch_range_vector,
                                        gold_positions].mean()

        # compute parent tag loss
        parent_tag_nll = F.cross_entropy(
            parent_tag_scores[batch_range_vector, gold_positions], span_tag)

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)
        child_probs = torch.sigmoid(child_scores)
        child_tag_probs = F.softmax(child_tag_scores, dim=-1)

        child_arc_loss = F.binary_cross_entropy_with_logits(child_scores,
                                                            child_arcs.float(),
                                                            reduction="none")
        child_arc_loss = (child_arc_loss *
                          mrc_mask.float()).sum() / mrc_mask.float().sum()
        child_tag_loss = F.cross_entropy(child_tag_scores.view(
            batch_size * seq_len, -1),
                                         child_tags.view(-1),
                                         reduction="none")
        child_tag_loss = (child_tag_loss * child_arcs.float().view(-1)
                          ).sum() / (child_arcs.float().sum() + 1e-8)

        return parent_probs, parent_tag_probs, child_probs, child_tag_probs, parent_arc_nll, parent_tag_nll, child_arc_loss, child_tag_loss
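
The child-arc loss above is a per-position binary cross-entropy averaged only over positions allowed by the mask. A toy sketch of that masked averaging (random scores, made-up mask):

import torch
import torch.nn.functional as F

batch, seq_len = 2, 5
child_scores = torch.randn(batch, seq_len)
child_arcs = torch.randint(0, 2, (batch, seq_len)).float()
mrc_mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 0]], dtype=torch.bool)

loss = F.binary_cross_entropy_with_logits(child_scores, child_arcs, reduction="none")
masked_loss = (loss * mrc_mask.float()).sum() / (mrc_mask.float().sum() + 1e-8)
print(float(masked_loss))
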
Code example #8
def debug_to_set(t: torch.BoolTensor) -> List[Set[int]]:
    return [{i for i, x in enumerate(batch) if x} for batch in t.numpy()]
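
A quick usage check (assuming `import torch` and `from typing import List, Set` are in scope for the function above):

import torch

t = torch.tensor([[True, False, True],
                  [False, False, True]])
print(debug_to_set(t))   # [{0, 2}, {2}]
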
Code example #9
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(int(
            self._input_dim / 2),
                                                                    dim=-1)
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [
            index.squeeze(-1) for index in span_indices.split(1, dim=-1)
        ]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP `SpanField` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP `SpanField` is inclusive.
        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(
                sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length size of the sequence_tensor.
            sequence_lengths = torch.ones_like(
                sequence_tensor[:, 0,
                                0], dtype=torch.long) * sequence_tensor.size(1)

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (exclusive_span_ends >=
                             sequence_lengths.unsqueeze(-1)).unsqueeze(-1)

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * ~end_sentinel_mask.squeeze(
            -1)
        exclusive_span_starts = exclusive_span_starts * ~start_sentinel_mask.squeeze(
            -1)

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (
                exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
            raise ValueError(
                f"Adjusted span indices must lie inside the length of the sequence tensor, "
                f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                f"{sequence_lengths}.")

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(
            forward_sequence, exclusive_span_starts)
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(
            forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(
            backward_sequence, exclusive_span_ends)
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(
            backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            forward_start_embeddings = (
                forward_start_embeddings * ~start_sentinel_mask +
                start_sentinel_mask * self._start_sentinel)
            backward_start_embeddings = (
                backward_start_embeddings * ~end_sentinel_mask +
                end_sentinel_mask * self._end_sentinel)

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(
            self._forward_combination,
            [forward_start_embeddings, forward_end_embeddings])
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(
            self._backward_combination,
            [backward_start_embeddings, backward_end_embeddings])
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_ends - span_starts,
                    num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.unsqueeze(-1)
        return span_embeddings
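
The repeated `util.batched_index_select` calls above all do the same thing: pick one embedding per span endpoint from the sequence. A minimal stand-in using `torch.gather` on toy tensors (not the AllenNLP helper itself):

import torch

batch, seq_len, dim, num_spans = 2, 6, 4, 3
sequence = torch.randn(batch, seq_len, dim)
span_ends = torch.tensor([[1, 3, 5],
                          [0, 2, 4]])                        # (batch, num_spans), inclusive indices

index = span_ends.unsqueeze(-1).expand(-1, -1, dim)          # (batch, num_spans, dim)
end_embeddings = sequence.gather(1, index)                   # (batch, num_spans, dim)
print(torch.allclose(end_embeddings[0, 1], sequence[0, 3]))  # True
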
Code example #10
    def _viterbi_decode(
        self, emissions: torch.FloatTensor, mask: torch.BoolTensor
    ) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[: seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list
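
The core of the recursion above is a single broadcasted sum of previous scores, transitions, and current emissions, followed by a max over the previous tag. A toy sketch of one such step (random scores, num_tags = 3):

import torch

batch_size, num_tags = 2, 3
score = torch.randn(batch_size, num_tags)          # best score ending in each tag so far
transitions = torch.randn(num_tags, num_tags)      # transitions[i, j]: score of moving from tag i to tag j
emission = torch.randn(batch_size, num_tags)       # emission scores at the current timestep

next_score = score.unsqueeze(2) + transitions + emission.unsqueeze(1)   # (batch, from_tag, to_tag)
next_score, backpointers = next_score.max(dim=1)   # max over the previous ("from") tag
print(next_score.shape, backpointers.shape)        # torch.Size([2, 3]) torch.Size([2, 3])
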
Code example #11
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        subtree_spans: torch.LongTensor = None,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            subtree_spans: [batch_size, num_words, 2]
        Returns:
            span_start_logits: [batch_size, num_words, num_words]
            span_end_logits: [batch_size, num_words, num_words]

        """
        # [bsz, seq_len, hidden]
        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # [bsz, seq_len, dim]
        subtree_start_representation = self._dropout(
            self.subtree_start_feedforward(encoded_text))
        subtree_end_representation = self._dropout(
            self.subtree_end_feedforward(encoded_text))
        # [bsz, seq_len, seq_len]
        span_start_scores = self.subtree_start_attention(
            subtree_start_representation, subtree_start_representation)
        span_end_scores = self.subtree_end_attention(
            subtree_end_representation, subtree_end_representation)

        # the start of a word's subtree span should be less than or equal to the word's position
        start_mask = word_mask.unsqueeze(-1) & (
            ~torch.triu(span_start_scores.bool(), 1))
        # the end of a word's subtree span should be greater than or equal to the word's position.
        end_mask = word_mask.unsqueeze(-1) & torch.triu(span_end_scores.bool())

        minus_inf = -1e8
        span_start_scores = span_start_scores + (
            ~start_mask).float() * minus_inf
        span_end_scores = span_end_scores + (~end_mask).float() * minus_inf

        output = (F.log_softmax(span_start_scores,
                                dim=-1), F.log_softmax(span_end_scores,
                                                       dim=-1))

        if subtree_spans is not None:

            start_loss = F.cross_entropy(
                span_start_scores.view(batch_size * seq_len, -1),
                subtree_spans[:, :, 0].view(-1))
            end_loss = F.cross_entropy(
                span_end_scores.view(batch_size * seq_len, -1),
                subtree_spans[:, :, 1].view(-1))
            span_loss = start_loss + end_loss
            output = output + (span_loss, )

        return output
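
The start/end masks above are meant to enforce triangular constraints: a subtree start may not lie to the right of the word, and a subtree end may not lie to its left. A toy illustration of such triangular masks built with `torch.triu` over a matrix of ones (an illustrative stand-in, not the exact mask construction above):

import torch

seq_len = 4
ones = torch.ones(seq_len, seq_len, dtype=torch.bool)
start_allowed = ~torch.triu(ones, diagonal=1)   # positions j <= i (lower triangle incl. diagonal)
end_allowed = torch.triu(ones)                  # positions j >= i (upper triangle incl. diagonal)
print(start_allowed.int())
print(end_allowed.int())
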
Code example #12
def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor,
                         mask: torch.BoolTensor) -> torch.Tensor:
    pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
    return F.cross_entropy(pred, true, reduction="none") * mask
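
A toy call of `masked_cross_entropy` (four items, three classes, the last item masked out); masked items contribute a loss of exactly zero:

import torch
import torch.nn.functional as F                  # needed by masked_cross_entropy above

pred = torch.randn(4, 3)                         # logits
true = torch.tensor([0, 2, 1, 1])                # gold classes
mask = torch.tensor([True, True, True, False])
print(masked_cross_entropy(pred, true, mask))    # per-item losses; the last entry is 0
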
Code example #13
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ):
        """
        Given a sequence tensor, extract spans, concatenate width embeddings
        when needed, and return representations of them.

        # Parameters

        sequence_tensor : `torch.FloatTensor`, required.
            A tensor of shape (batch_size, sequence_length, embedding_size)
            representing an embedded sequence of words.
        span_indices : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)`, where the last
            dimension represents the inclusive start and end indices of the
            span to be extracted from the `sequence_tensor`.
        sequence_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, sequence_length) representing padded
            elements of the sequence.
        span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, num_spans) representing the valid
            spans in the `indices` tensor. This mask is optional because
            sometimes it's easier to worry about masking after calling this
            function, rather than passing a mask directly.

        # Returns

        A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
        where `embedded_span_size` depends on the way spans are represented.
        """
        # shape (batch_size, num_spans, embedding_dim)
        span_embeddings = self._embed_spans(sequence_tensor, span_indices,
                                            sequence_mask, span_indices_mask)
        if self._span_width_embedding is not None:
            # width = end_index - start_index + 1 since `SpanField` uses inclusive indices.
            # But here we do not add 1 because we often initialize the span width
            # embedding matrix with `num_width_embeddings = max_span_width`
            # shape (batch_size, num_spans)
            widths_minus_one = span_indices[..., 1] - span_indices[..., 0]

            if self._bucket_widths:
                widths_minus_one = util.bucket_values(
                    widths_minus_one,
                    num_total_buckets=self.
                    _num_width_embeddings  # type: ignore
                )

            # Embed the span widths and concatenate to the rest of the representations.
            span_width_embeddings = self._span_width_embedding(
                widths_minus_one)
            span_embeddings = torch.cat(
                [span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            # Here we are masking the spans which were originally passed in as padding.
            return span_embeddings * span_indices_mask.unsqueeze(-1)

        return span_embeddings
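
The width feature above is `end - start` (no `+ 1`), looked up in an embedding table whose size matches the maximum span width. A small sketch of that lookup with a toy `nn.Embedding` (hypothetical sizes):

import torch
import torch.nn as nn

max_span_width, width_dim = 10, 5
span_width_embedding = nn.Embedding(max_span_width, width_dim)

span_indices = torch.tensor([[[0, 0], [2, 5], [4, 8]]])           # (batch=1, num_spans=3, 2), inclusive
widths_minus_one = span_indices[..., 1] - span_indices[..., 0]    # (1, 3) -> 0, 3, 4
width_embeddings = span_width_embedding(widths_minus_one)         # (1, 3, width_dim)
print(width_embeddings.shape)
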
Code example #14
    def _parse(
        self,
        embedded_text_input: torch.Tensor,
        mask: torch.BoolTensor,
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, sequence_length, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self.head_arc_feedforward(encoded_text)
        child_arc_representation = self.child_arc_feedforward(encoded_text)

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self.head_tag_feedforward(encoded_text)
        child_tag_representation = self.child_tag_feedforward(encoded_text)

        # calculate dimensions again as sequence_length is now + 1 from adding the head_sentinel
        batch_size, sequence_length, arc_dim = head_arc_representation.size()
        
        # now repeat the token representations to form a matrix:
        # shape (batch_size, sequence_length, sequence_length, arc_representation_dim)
        heads = head_arc_representation.repeat(1, sequence_length, 1).reshape(batch_size, sequence_length, sequence_length, arc_dim) # heads in one direction
        deps = child_arc_representation.repeat(1, sequence_length, 1).reshape(batch_size, sequence_length, sequence_length, arc_dim).transpose(1, 2) # deps in the other direction  
        
        # shape (batch_size, sequence_length, sequence_length, arc_representation_dim)
        combined_arcs = self.activation(heads + deps)

        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_out_layer(combined_arcs).squeeze(3)
        
        minus_inf = -1e8
        minus_mask = ~mask * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation, attended_arcs, mask
            )
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation, attended_arcs, mask
            )
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask,
            )
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask,
            )

        return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll
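
The repeat/reshape block in `_parse` builds, for every pair (i, j), the sum of the j-th head representation and the i-th child representation. The same pairwise grid can be written with broadcasting; the sketch below checks that the two forms agree on random toy tensors:

import torch

batch, seq_len, arc_dim = 2, 4, 6
head_arc = torch.randn(batch, seq_len, arc_dim)
child_arc = torch.randn(batch, seq_len, arc_dim)

# Broadcasting form: pairwise[b, i, j] = head_arc[b, j] + child_arc[b, i].
pairwise = head_arc.unsqueeze(1) + child_arc.unsqueeze(2)          # (batch, seq_len, seq_len, arc_dim)

# Repeat/reshape form used in the method above.
heads = head_arc.repeat(1, seq_len, 1).reshape(batch, seq_len, seq_len, arc_dim)
deps = child_arc.repeat(1, seq_len, 1).reshape(batch, seq_len, seq_len, arc_dim).transpose(1, 2)
print(torch.allclose(pairwise, heads + deps))                      # True
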
Code example #15
File: tmp.py Project: PYART0/PyART-demo
def masked_log_softmax(vector: torch.Tensor, mask: torch.BoolTensor, dim: int = -1) -> torch.Tensor:
    if mask is not None:
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        vector = vector + (mask + tiny_value_of_dtype(vector.dtype)).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)
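
A toy check of `masked_log_softmax`: masked positions receive a very large negative log-probability instead of NaNs. `tiny_value_of_dtype` comes from AllenNLP's `nn.util`; the stand-in below mirrors its float defaults (an assumption made for this sketch):

import torch

def tiny_value_of_dtype(dtype: torch.dtype) -> float:
    # Stand-in for allennlp.nn.util.tiny_value_of_dtype (assumed values).
    return 1e-13 if dtype in (torch.float, torch.double) else 1e-4

vector = torch.tensor([[2.0, 1.0, 0.5]])
mask = torch.tensor([[True, True, False]])
print(masked_log_softmax(vector, mask))   # the masked position gets a very large negative value
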
Code example #16
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        dep_idxs: torch.LongTensor,
        dep_tags: torch.LongTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
    ):

        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, _, encoding_dim = encoded_text.size()
        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask],
                              1)
        dep_idxs = torch.cat([dep_idxs.new_zeros(batch_size, 1), dep_idxs], 1)
        dep_tags = torch.cat([dep_tags.new_zeros(batch_size, 1), dep_tags], 1)

        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = ~word_mask * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, word_mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, word_mask)

        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=dep_idxs,
            head_tags=dep_tags,
            mask=word_mask,
        )

        return predicted_heads, predicted_head_tags, arc_nll, tag_nll
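A minimal sketch (random encoder output and a randomly initialised sentinel, none of it from the project above) of the head-sentinel step: a learned ROOT vector is prepended to every sentence, and the BoolTensor mask and gold annotations are padded so that position 0 always refers to the artificial root.

import torch

batch_size, seq_len, dim = 2, 5, 8
encoded_text = torch.randn(batch_size, seq_len, dim)
word_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
dep_idxs = torch.randint(0, seq_len, (batch_size, seq_len))

head_sentinel = torch.nn.Parameter(torch.randn(1, 1, dim))   # the learned ROOT embedding
encoded_text = torch.cat([head_sentinel.expand(batch_size, 1, dim), encoded_text], 1)
word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask], 1)
dep_idxs = torch.cat([dep_idxs.new_zeros(batch_size, 1), dep_idxs], 1)
print(encoded_text.shape, word_mask.shape, dep_idxs.shape)   # (2, 6, 8) (2, 6) (2, 6)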
Code example #17
0
File: tmp.py Project: PYART0/PyART-demo
def get_lengths_from_binary_sequence_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    return mask.sum(-1)
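A tiny usage sketch (invented mask): summing a BoolTensor padding mask over the last dimension recovers the unpadded length of every sequence, which is all the helper above does.

import torch

mask = torch.tensor([[True, True, True, False],
                     [True, False, False, False]])
lengths = mask.sum(-1)   # equivalent to get_lengths_from_binary_sequence_mask(mask)
print(lengths)           # tensor([3, 1])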
Code example #18
0
File: evaluation.py Project: kmkurn/ppt-eacl2021
def count_correct(
    heads: LongTensor,
    types: LongTensor,
    pred_heads: LongTensor,
    pred_types: LongTensor,
    mask: BoolTensor,
    nopunct_mask: BoolTensor,
    proj_mask: BoolTensor,
    root_idx: int = 0,
    type_idx: Optional[int] = None,
) -> Union["Counts", "TypeWiseCounts"]:
    # shape: (bsz, slen)
    assert heads.dim() == 2
    assert types.shape == heads.shape
    assert pred_heads.shape == heads.shape
    assert pred_types.shape == heads.shape
    assert mask.shape == heads.shape
    assert nopunct_mask.shape == heads.shape
    assert proj_mask.shape == heads.shape

    corr_heads = heads == pred_heads
    corr_types = types == pred_types

    if type_idx is None:
        root_mask = heads == root_idx
        nonproj_mask = ~torch.all(proj_mask | (~mask), dim=1, keepdim=True)

        usents = int(torch.all(corr_heads | (~mask), dim=1).long().sum())
        usents_nopunct = int(
            torch.all(corr_heads | (~mask) | (~nopunct_mask),
                      dim=1).long().sum())
        lsents = int(
            torch.all(corr_heads & corr_types | (~mask), dim=1).long().sum())
        lsents_nopunct = int(
            torch.all(corr_heads & corr_types | (~mask) | (~nopunct_mask),
                      dim=1).long().sum())
        uarcs = int((corr_heads & mask).long().sum())
        uarcs_nopunct = int((corr_heads & mask & nopunct_mask).long().sum())
        uarcs_nonproj = int((corr_heads & mask & nonproj_mask).long().sum())
        larcs = int((corr_heads & corr_types & mask).long().sum())
        larcs_nopunct = int(
            (corr_heads & corr_types & mask & nopunct_mask).long().sum())
        larcs_nonproj = int(
            (corr_heads & corr_types & mask & nonproj_mask).long().sum())
        roots = int((corr_heads & mask & root_mask).long().sum())
        n_sents = heads.size(0)
        n_arcs = int(mask.long().sum())
        n_arcs_nopunct = int((mask & nopunct_mask).long().sum())
        n_arcs_nonproj = int((mask & nonproj_mask).long().sum())
        n_roots = int((mask & root_mask).long().sum())

        return Counts(
            usents,
            usents_nopunct,
            lsents,
            lsents_nopunct,
            uarcs,
            uarcs_nopunct,
            uarcs_nonproj,
            larcs,
            larcs_nopunct,
            larcs_nonproj,
            roots,
            n_sents,
            n_arcs,
            n_arcs_nopunct,
            n_arcs_nonproj,
            n_roots,
        )

    assert type_idx is not None
    type_mask = types == type_idx

    uarcs = int((corr_heads & type_mask & mask).long().sum())
    uarcs_nopunct = int(
        (corr_heads & type_mask & nopunct_mask & mask).long().sum())
    larcs = int((corr_heads & corr_types & type_mask & mask).long().sum())
    larcs_nopunct = int((corr_heads & corr_types & type_mask & nopunct_mask
                         & mask).long().sum())
    n_arcs = int((type_mask & mask).long().sum())
    n_arcs_nopunct = int((type_mask & nopunct_mask & mask).long().sum())

    return TypeWiseCounts(type_idx, uarcs, uarcs_nopunct, larcs, larcs_nopunct,
                          n_arcs, n_arcs_nopunct)
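A toy walk-through (invented gold and predicted tensors) of the boolean algebra used above: unlabelled and labelled attachment counts are just sums over conjunctions of BoolTensors.

import torch

heads      = torch.tensor([[0, 1, 1]])
pred_heads = torch.tensor([[0, 1, 2]])
types      = torch.tensor([[3, 4, 5]])
pred_types = torch.tensor([[3, 4, 5]])
mask       = torch.tensor([[True, True, True]])

corr_heads = heads == pred_heads
corr_types = types == pred_types
uarcs = int((corr_heads & mask).long().sum())               # 2 heads are correct
larcs = int((corr_heads & corr_types & mask).long().sum())  # 2 arcs also carry the right label
print(uarcs, larcs)                                         # 2 2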
Code example #19
0
File: ud.py Project: lgessler/embur
    def _construct_loss(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        head_indices: torch.Tensor,
        deprel_labels: torch.Tensor,
        mask: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        deprel_labels : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : `torch.BoolTensor`, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        # Returns

        arc_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc loss.
        tag_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc tag loss.
        """
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = (
            masked_log_softmax(attended_arcs, mask) * mask.unsqueeze(2) * mask.unsqueeze(1)
        )

        # shape (batch_size, sequence_length, num_deprel_labels)
        head_tag_logits = self._get_deprel_labels(
            head_tag_representation, child_tag_representation, head_indices
        )
        normalised_head_tag_logits = masked_log_softmax(
            head_tag_logits, mask.unsqueeze(-1)
        ) * mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = (
            timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        )
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, deprel_labels]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
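A compact sketch (toy shapes, random logits) of the advanced-indexing step in the loss above: for every (batch, child) pair we pick out the log-probability assigned to that token's gold head, i.e. normalised_arc_logits[b, child, head].

import torch

batch_size, seq_len = 2, 4
log_probs = torch.log_softmax(torch.randn(batch_size, seq_len, seq_len), dim=-1)
head_indices = torch.randint(0, seq_len, (batch_size, seq_len))

range_vector = torch.arange(batch_size).unsqueeze(1)              # shape (batch_size, 1)
child_index = torch.arange(seq_len).expand(batch_size, seq_len)   # shape (batch_size, seq_len)
arc_loss = log_probs[range_vector, child_index, head_indices]     # shape (batch_size, seq_len)
print(arc_loss.shape)                                             # torch.Size([2, 4])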
Code example #20
0
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [
            index.squeeze(-1) for index in span_indices.split(1, dim=-1)
        ]

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        if not self._use_exclusive_start_indices:
            if sequence_tensor.size(-1) != self._input_dim:
                raise ValueError(
                    f"Dimension mismatch expected ({sequence_tensor.size(-1)}) "
                    f"received ({self._input_dim}).")
            start_embeddings = util.batched_index_select(
                sequence_tensor, span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor,
                                                       span_ends)

        else:
            # We want `exclusive` span starts, so we remove 1 from the forward span starts
            # as the AllenNLP `SpanField` is inclusive.
            # shape (batch_size, num_spans)
            exclusive_span_starts = span_starts - 1
            # shape (batch_size, num_spans, 1)
            start_sentinel_mask = (exclusive_span_starts == -1).unsqueeze(-1)
            exclusive_span_starts = exclusive_span_starts * ~start_sentinel_mask.squeeze(
                -1)

            # We'll check the indices here at runtime, because it's difficult to debug
            # if this goes wrong and it's tricky to get right.
            if (exclusive_span_starts < 0).any():
                raise ValueError(
                    f"Adjusted span indices must lie inside the the sequence tensor, "
                    f"but found: exclusive_span_starts: {exclusive_span_starts}."
                )

            start_embeddings = util.batched_index_select(
                sequence_tensor, exclusive_span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor,
                                                       span_ends)

            # We're using sentinels, so we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with the start sentinel.
            start_embeddings = (start_embeddings * ~start_sentinel_mask +
                                start_sentinel_mask * self._start_sentinel)

        combined_tensors = util.combine_tensors(
            self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_ends - span_starts,
                    num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            combined_tensors = torch.cat(
                [combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1)

        return combined_tensors
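A small sketch (made-up span indices) of the exclusive-start bookkeeping above: spans that start at token 0 would get start index -1, so they are flagged in a sentinel mask and clamped back to 0 before any embedding lookup, and the sentinel embedding is substituted for them afterwards.

import torch

span_starts = torch.tensor([[0, 2, 3]])          # (batch_size, num_spans), inclusive starts
exclusive_span_starts = span_starts - 1          # tensor([[-1, 1, 2]])
start_sentinel_mask = (exclusive_span_starts == -1).unsqueeze(-1)
exclusive_span_starts = exclusive_span_starts * ~start_sentinel_mask.squeeze(-1)
print(exclusive_span_starts)                     # tensor([[0, 1, 2]])
print(start_sentinel_mask.squeeze(-1))           # tensor([[ True, False, False]]) -> use the sentinel here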
Code example #21
0
    def forward(self,
                inputs: torch.Tensor,
                mask: torch.BoolTensor,
                hidden_state: torch.Tensor = None) -> torch.Tensor:

        if self.stateful and mask is None:
            raise ValueError("Always pass a mask with stateful RNNs.")
        if self.stateful and hidden_state is not None:
            raise ValueError(
                "Stateful RNNs provide their own initial hidden_state.")

        if mask is None:
            return self._module(inputs, hidden_state)[0]

        batch_size, total_sequence_length = mask.size()

        packed_sequence_output, final_states, restoration_indices = self.sort_and_run_forward(
            self._module, inputs, mask, hidden_state)

        unpacked_sequence_tensor, _ = pad_packed_sequence(
            packed_sequence_output, batch_first=True)

        num_valid = unpacked_sequence_tensor.size(0)
        # Some RNNs (GRUs) only return one state as a Tensor.  Others (LSTMs) return two.
        # If one state, use a single element list to handle in a consistent manner below.
        if not isinstance(final_states, (list, tuple)) and self.stateful:
            final_states = [final_states]

        # Add back invalid rows.
        if num_valid < batch_size:
            _, length, output_dim = unpacked_sequence_tensor.size()
            zeros = unpacked_sequence_tensor.new_zeros(batch_size - num_valid,
                                                       length, output_dim)
            unpacked_sequence_tensor = torch.cat(
                [unpacked_sequence_tensor, zeros], 0)

            # The states also need to have invalid rows added back.
            if self.stateful:
                new_states = []
                for state in final_states:
                    num_layers, _, state_dim = state.size()
                    zeros = state.new_zeros(num_layers, batch_size - num_valid,
                                            state_dim)
                    new_states.append(torch.cat([state, zeros], 1))
                final_states = new_states

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2SeqEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - unpacked_sequence_tensor.size(
            1)
        if sequence_length_difference > 0:
            zeros = unpacked_sequence_tensor.new_zeros(
                batch_size, sequence_length_difference,
                unpacked_sequence_tensor.size(-1))
            unpacked_sequence_tensor = torch.cat(
                [unpacked_sequence_tensor, zeros], 1)

        if self.stateful:
            self._update_states(final_states, restoration_indices)

        # Restore the original indices and return the sequence.
        return unpacked_sequence_tensor.index_select(0, restoration_indices)
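A toy illustration (random tensors, no RNN) of the length fix-up near the end of the method above: after packing and unpacking, the output can be shorter than the mask's sequence length, so zero timesteps are appended until the shapes agree again.

import torch

total_sequence_length = 6
unpacked_sequence_tensor = torch.randn(2, 4, 3)   # the RNN only saw 4 real timesteps
difference = total_sequence_length - unpacked_sequence_tensor.size(1)
if difference > 0:
    zeros = unpacked_sequence_tensor.new_zeros(
        unpacked_sequence_tensor.size(0), difference, unpacked_sequence_tensor.size(-1))
    unpacked_sequence_tensor = torch.cat([unpacked_sequence_tensor, zeros], 1)
print(unpacked_sequence_tensor.shape)             # torch.Size([2, 6, 3])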
Code example #22
0
File: dep.py Project: wainshine/HanLP
 def feed_batch(self, h: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask: torch.BoolTensor,
                decoder: torch.nn.Module):
     logits = super().feed_batch(h, batch, mask, decoder)
     mask = mask.clone()
     mask[:, 0] = 0
     return logits, mask
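A minimal sketch (dummy mask) of the mask adjustment above: the ROOT column is zeroed on a clone, so downstream decoding ignores position 0 while the caller's original mask stays untouched.

import torch

mask = torch.ones(2, 4, dtype=torch.bool)
adjusted = mask.clone()
adjusted[:, 0] = 0
print(adjusted[0])   # tensor([False,  True,  True,  True])
print(mask[0])       # tensor([True, True, True, True]) -- unchanged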
Code example #23
0
File: distillation.py Project: convobox/ParlAI
    def _get_and_record_component_attention_loss(
        self,
        teacher_attention_matrices: List[Dict[str, torch.Tensor]],
        student_attention_matrices: List[Dict[str, torch.Tensor]],
        mask: torch.BoolTensor,
        tokens_per_example: torch.Tensor,
        num_tokens: torch.Tensor,
        mapped_layers: List[int],
        attn_type: str,
        metric_name: str,
    ) -> torch.Tensor:
        """
        Calculate the given attention loss and register it as the given metric name.
        """

        assert isinstance(self, TorchGeneratorAgent)
        # Code relies on methods of TorchGeneratorAgent

        # Select the right attention matrices
        selected_student_attn_matrices = [
            layer_matrices[attn_type]
            for layer_matrices in student_attention_matrices
        ]
        selected_teacher_attn_matrices = [
            layer_matrices[attn_type]
            for layer_matrices in teacher_attention_matrices
        ]

        batch_size = mask.size(0)
        per_layer_losses = []
        per_layer_per_example_losses = []
        for student_layer_idx, teacher_layer_idx in enumerate(mapped_layers):
            raw_layer_loss = F.mse_loss(
                input=selected_student_attn_matrices[student_layer_idx],
                target=selected_teacher_attn_matrices[teacher_layer_idx],
                reduction='none',
            )
            clamped_layer_loss = torch.clamp(raw_layer_loss,
                                             min=0,
                                             max=NEAR_INF_FP16)
            # Prevent infs from appearing in the loss term. Especially important with
            # fp16
            reshaped_layer_loss = clamped_layer_loss.view(
                batch_size, -1, clamped_layer_loss.size(-2),
                clamped_layer_loss.size(-1))
            # [batch size, n heads, query length, key length]
            mean_layer_loss = reshaped_layer_loss.mean(dim=(1, 3))
            # Take the mean over the attention heads and the key length
            assert mean_layer_loss.shape == mask.shape
            masked_layer_loss = mean_layer_loss * mask
            layer_loss_per_example = masked_layer_loss.sum(
                dim=-1)  # Sum over token dim
            layer_loss = masked_layer_loss.div(num_tokens).sum()
            # Divide before summing over examples so that values don't get too large
            per_layer_losses.append(layer_loss)
            per_layer_per_example_losses.append(layer_loss_per_example)
        attn_loss = torch.stack(per_layer_losses).mean()
        attn_loss_per_example = torch.stack(per_layer_per_example_losses,
                                            dim=1).mean(dim=1)

        # Record metric
        self.record_local_metric(
            metric_name,
            AverageMetric.many(attn_loss_per_example, tokens_per_example))

        return attn_loss
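A schematic, self-contained sketch (random per-token losses, invented mask) of the masking and normalisation at the end of the routine above: padded tokens are zeroed out, losses are summed per example, and the division by the total token count happens before summing over the batch.

import torch

per_token_loss = torch.rand(2, 5)
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True,  True]])
num_tokens = mask.sum()

masked_loss = per_token_loss * mask
loss_per_example = masked_loss.sum(dim=-1)   # shape (batch_size,)
loss = masked_loss.div(num_tokens).sum()     # scalar, averaged over real tokens
print(loss_per_example.shape, float(loss))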
Code example #24
0
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        mask_positions: torch.BoolTensor,
        target_ids: TextFieldTensors = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            The output of `TextField.as_tensor()` for a batch of sentences.
        mask_positions : `torch.LongTensor`
            The positions in `tokens` that correspond to [MASK] tokens that we should try to fill
            in.  Shape should be (batch_size, num_masks).
        target_ids : `TextFieldTensors`
            This is a list of token ids that correspond to the mask positions we're trying to fill.
            It is the output of a `TextField`, purely for convenience, so we can handle wordpiece
            tokenizers and such without having to do crazy things in the dataset reader.  We assume
            that there is exactly one entry in the dictionary, and that it has a shape identical to
            `mask_positions` - one target token per mask position.
        """

        targets = None
        if target_ids is not None:
            # A bit of a hack to get the right targets out of the TextField output...
            if len(target_ids) != 1:
                targets = target_ids["bert"]["token_ids"]
            else:
                targets = list(target_ids.values())[0]["tokens"]
        mask_positions = mask_positions.squeeze(-1)
        batch_size, num_masks = mask_positions.size()
        if targets is not None and targets.size() != mask_positions.size():
            raise ValueError(
                f"Number of targets ({targets.size()}) and number of masks "
                f"({mask_positions.size()}) are not equal")

        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
        else:
            contextual_embeddings = embeddings

        # Does advanced indexing to get the embeddings of just the mask positions, which is what
        # we're trying to predict.
        batch_index = torch.arange(0, batch_size).long().unsqueeze(1)
        mask_embeddings = contextual_embeddings[batch_index, mask_positions]

        target_logits = self._language_model_head(
            self._dropout(mask_embeddings))

        vocab_size = target_logits.size(-1)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        k = min(vocab_size,
                5)  # min here largely because tests use small vocab
        top_probs, top_indices = probs.topk(k=k, dim=-1)

        output_dict = {"probabilities": top_probs, "top_indices": top_indices}

        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)

        if targets is not None:
            target_logits = target_logits.view(batch_size * num_masks,
                                               vocab_size)
            targets = targets.view(batch_size * num_masks)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict["loss"] = loss

        return output_dict
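A minimal sketch (random embeddings, made-up mask positions) of the advanced indexing used above to pull out just the [MASK] position embeddings from the contextual representation.

import torch

batch_size, num_tokens, dim = 2, 6, 4
contextual_embeddings = torch.randn(batch_size, num_tokens, dim)
mask_positions = torch.tensor([[1, 4], [2, 5]])                  # (batch_size, num_masks)

batch_index = torch.arange(0, batch_size).long().unsqueeze(1)    # (batch_size, 1)
mask_embeddings = contextual_embeddings[batch_index, mask_positions]
print(mask_embeddings.shape)                                     # torch.Size([2, 2, 4])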
Code example #25
0
    def _unfold_long_sequences(
        self,
        embeddings: torch.FloatTensor,
        mask: torch.BoolTensor,
        batch_size: int,
        num_segment_concat_wordpieces: int,
    ) -> torch.FloatTensor:
        """
        We take 2D segments of a long sequence and flatten them out to get the whole sequence
        representation while removing unnecessary special tokens.

        [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
        -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

        We truncate the start and end tokens for all segments, recombine the segments,
        and manually add back the start and end tokens.

        # Parameters

        embeddings: `torch.FloatTensor`
            Shape: [batch_size * num_segments, self._max_length, embedding_size].
        mask: `torch.BoolTensor`
            Shape: [batch_size * num_segments, self._max_length].
            The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
            in `forward()`.
        batch_size: `int`
        num_segment_concat_wordpieces: `int`
            The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
            the original `token_ids.size(1)`.

        # Returns:

        embeddings: `torch.FloatTensor`
            Shape: [batch_size, self._num_wordpieces, embedding_size].
        """
        def lengths_to_mask(lengths, max_len, device):
            return torch.arange(max_len, device=device).expand(
                lengths.size(0), max_len) < lengths.unsqueeze(1)

        device = embeddings.device
        num_segments = int(embeddings.size(0) / batch_size)
        embedding_size = embeddings.size(2)

        # We want to remove all segment-level special tokens but maintain sequence-level ones
        num_wordpieces = num_segment_concat_wordpieces - (
            num_segments - 1) * self._num_added_tokens

        embeddings = embeddings.reshape(batch_size,
                                        num_segments * self._max_length,
                                        embedding_size)
        mask = mask.reshape(batch_size, num_segments * self._max_length)
        # We assume that all 1s in the mask precede all 0s, and add an assert for that.
        # Open an issue on GitHub if this breaks for you.
        # Shape: (batch_size,)
        seq_lengths = mask.sum(-1)
        if not (lengths_to_mask(seq_lengths, mask.size(1), device)
                == mask).all():
            raise ValueError(
                "Long sequence splitting only supports masks with all 1s preceding all 0s."
            )
        # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
        end_token_indices = (
            seq_lengths.unsqueeze(-1) -
            torch.arange(self._num_added_end_tokens, device=device) - 1)

        # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
        start_token_embeddings = embeddings[:, :self._num_added_start_tokens, :]
        # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
        end_token_embeddings = batched_index_select(embeddings,
                                                    end_token_indices)

        embeddings = embeddings.reshape(batch_size, num_segments,
                                        self._max_length, embedding_size)
        embeddings = embeddings[:, :, self._num_added_start_tokens:-self._num_added_end_tokens, :]  # truncate segment-level start/end tokens
        embeddings = embeddings.reshape(batch_size, -1,
                                        embedding_size)  # flatten

        # Now try to put end token embeddings back which is a little tricky.

        # The number of segment each sequence spans, excluding padding. Mimicking ceiling operation.
        # Shape: (batch_size,)
        num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length
        # The number of indices that end tokens should shift back.
        num_removed_non_end_tokens = (
            num_effective_segments * self._num_added_tokens -
            self._num_added_end_tokens)
        # Shape: (batch_size, self._num_added_end_tokens)
        end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
        assert (end_token_indices >= self._num_added_start_tokens).all()
        # Add space for end embeddings
        embeddings = torch.cat(
            [embeddings, torch.zeros_like(end_token_embeddings)], 1)
        # Add end token embeddings back
        embeddings.scatter_(
            1,
            end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings),
            end_token_embeddings)

        # Now put back start tokens. We can do this before putting back end tokens, but then
        # we need to change `num_removed_non_end_tokens` a little.
        embeddings = torch.cat([start_token_embeddings, embeddings], 1)

        # Truncate to original length
        embeddings = embeddings[:, :num_wordpieces, :]
        return embeddings
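A stand-alone illustration (toy mask only) of the lengths_to_mask helper defined above and of the "all 1s precede all 0s" check it supports.

import torch

def lengths_to_mask(lengths, max_len, device):
    return torch.arange(max_len, device=device).expand(
        lengths.size(0), max_len) < lengths.unsqueeze(1)

mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])
seq_lengths = mask.sum(-1)                                   # tensor([3, 2])
rebuilt = lengths_to_mask(seq_lengths, mask.size(1), mask.device)
print(bool((rebuilt == mask).all()))                         # True: the mask is contiguous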
Code example #26
0
    def forward(self,
                inputs: torch.Tensor,
                mask: torch.BoolTensor = None) -> torch.FloatTensor:
        """
        # Parameters

        inputs : `torch.FloatTensor`, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : `torch.BoolTensor`, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        # Returns

        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps).bool()

        # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
        combined_projection = self._combined_projection(inputs)
        # split by attention dim - if values_dim > attention_dim, we will get more
        # than 3 elements returned. All of the rest are the values vector, so we
        # just concatenate them back together again below.
        queries, keys, *values = combined_projection.split(
            self._attention_dim, -1)
        queries = queries.contiguous()
        keys = keys.contiguous()
        values = torch.cat(values, -1).contiguous()
        # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
        values_per_head = values.view(batch_size, timesteps, num_heads,
                                      int(self._values_dim / num_heads))
        values_per_head = values_per_head.transpose(1, 2).contiguous()
        values_per_head = values_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._values_dim / num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        queries_per_head = queries.view(batch_size, timesteps, num_heads,
                                        int(self._attention_dim / num_heads))
        queries_per_head = queries_per_head.transpose(1, 2).contiguous()
        queries_per_head = queries_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._attention_dim / num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        keys_per_head = keys.view(batch_size, timesteps, num_heads,
                                  int(self._attention_dim / num_heads))
        keys_per_head = keys_per_head.transpose(1, 2).contiguous()
        keys_per_head = keys_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._attention_dim / num_heads))

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = torch.bmm(queries_per_head / self._scale,
                                        keys_per_head.transpose(1, 2))

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = masked_softmax(
            scaled_similarities,
            mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps),
            memory_efficient=True,
        )
        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps,
                               int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
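A toy sketch (random values, no learned projections) of the head-splitting reshape used repeatedly above: (batch, timesteps, dim) becomes (num_heads * batch, timesteps, dim / num_heads) so attention can be computed for all heads with one batched matmul.

import torch

batch_size, timesteps, dim, num_heads = 2, 5, 8, 4
values = torch.randn(batch_size, timesteps, dim)

per_head = values.view(batch_size, timesteps, num_heads, dim // num_heads)
per_head = per_head.transpose(1, 2).contiguous()
per_head = per_head.view(batch_size * num_heads, timesteps, dim // num_heads)
print(per_head.shape)   # torch.Size([8, 5, 2])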
Code example #27
 def forward(self, tokens: torch.Tensor, mask: torch.BoolTensor):
     if mask is not None:
         tokens = tokens * mask.unsqueeze(-1)
     output = self.pool(self._module(tokens))
     return output
Code example #28
0
def episode():
    done = False
    rewardsSum = 0
    qSum = 0
    lossSum = 0

    env.reset()

    last_screen = resize(get_screen(env)).unsqueeze(0).to(device)
    current_screen = resize(get_screen(env)).unsqueeze(0).to(device)
    state = current_screen - last_screen
    maxHeight = -100

    while not done:

        action, q = agent.selectAction(state)
        if q > -1:
            qSum += q

        obs, reward, done, _ = env.step(action)

        if obs[0] >= 0.5:
            reward += 10

        maxHeight = max(obs[0], maxHeight)

        env.render()

        last_screen = current_screen
        current_screen = resize(get_screen(env)).unsqueeze(0).to(device)
        if not done:
            nextState = current_screen - last_screen
        else:
            nextState = None

        rewardT = Variable(FloatTensor([reward])).to(device)
        doneT = Variable(BoolTensor([done])).to(device)

        rewardsSum = np.add(rewardsSum, reward)
        
        loss = agent.trainDQN()
        lossSum += loss
        agent.addMemory((state, action, rewardT, nextState, doneT), loss)

        state = nextState

    avgQ = qSum/200

    episode_score.append(rewardsSum)
    episode_qs.append(avgQ)
    episode_height.append(maxHeight)
    episode_loss.append(lossSum)
    episode_decay.append(agent.eps_threshold)
    plot_episode()

    #if ep % agent.sync == 0:
    #   agent.targetNetwork.load_state_dict(agent.trainNetwork.state_dict())

    if ep % 150 == 0:
        agent.save()

    print("now epsilon is {}, the reward is {} with loss {} in episode {}".format(
        agent.epsilon, rewardsSum, lossSum, ep))
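A hedged sketch (dummy frames, no gym environment or agent) of the state construction used in the loop above: the agent's state is simply the pixel difference between two consecutive rendered screens.

import torch

# Stand-ins for resize(get_screen(env)): two consecutive rendered frames.
last_screen = torch.rand(1, 3, 40, 60)
current_screen = torch.rand(1, 3, 40, 60)
state = current_screen - last_screen   # the frame difference fed to the DQN
print(state.shape)                     # torch.Size([1, 3, 40, 60])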