Example #1
    def __call__(self, x: Tensor, y: LongTensor) -> Tensor:
        y = y.flatten(0, -1)
        # Start from a uniform share of the redistributed mass over the classes that
        # are neither the gold class nor one of the ignored ids.
        y_float = torch.zeros(x.shape[0] * x.shape[1], self.nc, device=x.device, dtype=torch.float)
        y_float.fill_(self.mass_redistribution / (self.nc - (1 + len(self.ignore_index))))
        # Place the remaining 1 - mass_redistribution on the gold class.
        y_float.scatter_(1, y.unsqueeze(1), 1 - self.mass_redistribution)
        # Zero out the target rows whose label is one of the ignored indices.
        mask = torch.zeros_like(y, dtype=torch.bool)
        for idx in self.ignore_index:
            mask = torch.bitwise_or(mask, y == idx)
        y_float[mask.unsqueeze(1).repeat(1, self.nc)] = 0
        return self.loss_fn(torch.log_softmax(x.view(-1, self.nc), dim=-1), y_float)
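Example #1 builds label-smoothed targets: probability 1 - mass_redistribution goes to the gold class, the remainder is spread uniformly over the other non-ignored classes, and rows whose label is in ignore_index get an all-zero target, so they contribute nothing to the loss. Below is a minimal standalone sketch of the same idea; the function name, the choice of nn.KLDivLoss, and the assumption that ignored labels are valid class ids (which the scatter_ call requires) are illustrative, not taken from the original.

import torch
from torch import nn

def smoothed_targets(y: torch.LongTensor, nc: int, mass: float, ignore_index=(0,)) -> torch.Tensor:
    # Uniform share of the redistributed mass over the non-gold, non-ignored classes.
    y = y.flatten()
    y_float = torch.full((y.numel(), nc), mass / (nc - (1 + len(ignore_index))))
    # The remaining 1 - mass goes to the gold class.
    y_float.scatter_(1, y.unsqueeze(1), 1 - mass)
    # Rows whose label is an ignored id become all zeros.
    mask = torch.zeros_like(y, dtype=torch.bool)
    for idx in ignore_index:
        mask |= y == idx
    y_float[mask] = 0.0
    return y_float

x = torch.randn(2, 3, 5)                    # (batch, seq, num_classes) logits
y = torch.tensor([[1, 2, 0], [4, 1, 3]])    # class 0 doubles as the ignored padding label here
targets = smoothed_targets(y, nc=5, mass=0.1)
# KLDivLoss expects log-probabilities as input and a probability distribution as target,
# matching the log_softmax call in the example above.
loss = nn.KLDivLoss(reduction="batchmean")(torch.log_softmax(x.view(-1, 5), dim=-1), targets)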
Example #2
    def __call__(self, x: FloatTensor, aux: LongTensor):
        """Perform the forward computation.

        Parameters:
            x: input tensor of shape (batch, n_features); each column is processed by its own sub-module.
            aux: per-example index selecting which output unit of the summed result to return.
        """
        outputs = torch.cat(tuple(self.module_list[d](x[:, (d, )]).unsqueeze(2)
                                  for d in range(x.shape[1])),
                            dim=2).sum(dim=2)
        result = outputs[torch.arange(len(outputs)), aux.flatten()]
        return result
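Example #2 feeds each input column through its own sub-module, sums the per-column outputs, and then uses advanced indexing to pick, for every row, the output unit named by aux. A small self-contained sketch of the pattern; the Linear modules and the sizes below are illustrative assumptions, since the original only shows the forward pass.

import torch
from torch import nn

n_features, n_outputs, batch = 3, 4, 5
# One small network per input feature, mirroring self.module_list above.
module_list = nn.ModuleList(nn.Linear(1, n_outputs) for _ in range(n_features))

x = torch.randn(batch, n_features)
aux = torch.randint(0, n_outputs, (batch,))

# Each module sees only its own column; the contributions are stacked and summed.
outputs = torch.cat(tuple(module_list[d](x[:, (d,)]).unsqueeze(2)
                          for d in range(n_features)), dim=2).sum(dim=2)   # (batch, n_outputs)

# For row i, take outputs[i, aux[i]].
result = outputs[torch.arange(len(outputs)), aux.flatten()]                # (batch,)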
Example #3
    def forward(
        self,
        word_ids: TextFieldTensors,
        entity_start_positions: torch.LongTensor,
        entity_end_positions: torch.LongTensor,
        original_entity_spans: torch.LongTensor,
        doc_id: List[str],
        labels: torch.LongTensor = None,
        entity_ids: torch.LongTensor = None,
        entity_position_ids: torch.LongTensor = None,
        input_words: List[List[str]] = None,
        **kwargs,
    ):
        feature_vector = self.feature_extractor(word_ids[self.text_field_key],
                                                entity_start_positions,
                                                entity_end_positions,
                                                entity_ids,
                                                entity_position_ids)

        feature_vector = self.dropout(feature_vector)
        logits = self.classifier(feature_vector)
        prediction_logits, prediction = logits.max(dim=-1)
        output_dict = {
            "logits": logits,
            "prediction": prediction,
            "input": input_words
        }

        if labels is not None:
            output_dict["loss"] = self.criterion(logits.flatten(0, 1),
                                                 labels.flatten())
            self.span_accuracy(logits, labels, mask=(labels != -1))
            self.span_f1(prediction, labels, prediction_logits,
                         original_entity_spans, doc_id, input_words)

        return output_dict
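In Example #3 the loss flattens the (batch, num_spans, num_labels) logits and the (batch, num_spans) labels so a standard classification criterion can be applied per span; the (labels != -1) mask suggests that -1 marks padded spans. A short sketch of that shape handling, assuming a cross-entropy criterion with ignore_index=-1 (the actual self.criterion is not shown in the original):

import torch
from torch import nn

batch, num_spans, num_labels = 2, 4, 6
logits = torch.randn(batch, num_spans, num_labels)
labels = torch.tensor([[0, 3, -1, 5],
                       [2, -1, -1, 1]])              # -1 = padded span

criterion = nn.CrossEntropyLoss(ignore_index=-1)
loss = criterion(logits.flatten(0, 1),               # (batch * num_spans, num_labels)
                 labels.flatten())                   # (batch * num_spans,)

prediction_logits, prediction = logits.max(dim=-1)   # per-span best score and label id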
Example #4
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                lang_ids: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        lang_ids : ``torch.LongTensor``, optional
            Language ids for the wordpieces. All ids in a batch are expected to refer
            to the same language; a single language embedding is looked up and passed
            on to the underlying model.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # This can then be fed into BERT without any sentence length issues. Keep in mind
        # that the memory consumption can dramatically increase for large batches with extremely
        # long sentences.
        needs_split = full_seq_len > self.max_pieces
        last_window_size = 0
        if needs_split:
            # Split the flattened list by the window size, `max_pieces`
            split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))

            # We want all sequences to be the same length, so pad the last sequence
            last_window_size = split_input_ids[-1].size(-1)
            padding_amount = self.max_pieces - last_window_size
            split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0)

            # Now combine the sequences along the batch dimension
            input_ids = torch.cat(split_input_ids, dim=0)
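            # Illustration (numbers are made up, not from the original): with
            # max_pieces = 4 and input_ids of shape (2, 10), split() yields windows of
            # widths 4, 4 and 2; the last is padded to width 4, and the cat along
            # dim 0 gives a (6, 4) tensor, i.e. the windows stacked into the batch.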

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        lang_embedding = None
        if lang_ids is not None:
            lang_set = set(lang_ids.flatten().tolist())
            assert len(lang_set) == 1, 'all tokens in a batch must come from the same language'
            lang_embedding = self.language_embedder(next(iter(lang_ids)))

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers, _ = self.bert_model(input_ids=util.combine_initial_dims(input_ids),
                                                token_type_ids=util.combine_initial_dims(token_type_ids),
                                                attention_mask=util.combine_initial_dims(input_mask),
                                                lang_embedding=lang_embedding.to(input_ids.device) if lang_embedding is not None else None)
        all_encoder_layers = torch.stack(all_encoder_layers)

        if needs_split:
            # First, unpack the output embeddings into one long sequence again
            unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

            # Next, select indices of the sequence such that it will result in embeddings representing the original
            # sentence. To capture maximal context, the indices will be the middle part of each embedded window
            # sub-sequence (plus any leftover start and final edge windows), e.g.,
            #  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
            # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
            # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
            # and final windows with indices [0, 1] and [14, 15] respectively.

            # Find the stride as half the max pieces, ignoring the special start and end tokens
            # Calculate an offset to extract the centermost embeddings of each window
            stride = (self.max_pieces - self.start_tokens - self.end_tokens) // 2
            stride_offset = stride // 2 + self.start_tokens
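            # With the example above (max_pieces = 8, assuming one [CLS] start token and one
            # [SEP] end token): stride = (8 - 2) // 2 = 3 and stride_offset = 3 // 2 + 1 = 2,
            # so first_window is [0, 1] and the max-context positions within each window
            # are those with i % 8 in {2, 3, 4}.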

            first_window = list(range(stride_offset))

            max_context_windows = [i for i in range(full_seq_len)
                                   if stride_offset - 1 < i % self.max_pieces < stride_offset + stride]

            # Use the pre-padding size of the last window so that its tail is kept even
            # when full_seq_len is an exact multiple of max_pieces.
            final_window_start = full_seq_len - last_window_size + stride_offset + stride
            final_window = list(range(final_window_start, full_seq_len))

            select_indices = first_window + max_context_windows + final_window

            initial_dims.append(len(select_indices))

            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # Recombine the outputs of all layers
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # recombined = torch.cat(combined, dim=2)
        input_mask = (recombined_embeddings != 0).long()

        # At this point, recombined_embeddings is
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            dims = initial_dims if needs_split else input_ids.size()
            layers = util.uncombine_initial_dims(recombined_embeddings, dims)
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(offsets2d.size(0),
                                                 device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
            selected_embeddings = recombined_embeddings[:, range_vector, offsets2d]

            layers = util.uncombine_initial_dims(selected_embeddings, offsets.size())

        if self._scalar_mix is not None:
            output = self._scalar_mix(layers, input_mask)
        elif self.combine_layers == "last":
            output = layers[-1]
        else:
            output = layers

        return output
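The offsets mechanism described in the docstring selects one wordpiece embedding per original token via advanced indexing, which is what the range_vector/offsets2d block near the end of Example #4 does (there with an extra leading layers dimension). Below is a standalone sketch of that selection step on the docstring's "Definitely not" example; all tensors are dummies with illustrative shapes.

import torch

batch, num_wordpieces, dim = 2, 5, 8
embeddings = torch.randn(batch, num_wordpieces, dim)   # one embedding per wordpiece

# "Last wordpiece" offsets for ["Def", "##in", "##ite", "##ly", "not"]:
# the tokens "Definitely" and "not" end at wordpiece positions 3 and 4.
offsets = torch.tensor([[3, 4],
                        [3, 4]])                       # (batch, num_orig_tokens)

range_vector = torch.arange(batch).unsqueeze(1)        # (batch, 1), broadcasts over the tokens
selected = embeddings[range_vector, offsets]           # (batch, num_orig_tokens, dim)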