Example 1
    def _get_mask_for_eval(self,
                           mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask
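A minimal standalone sketch of the masking loop above, with made-up tag ids (3 and 7 stand in for whatever `self._pos_to_ignore` actually contains):

import torch

pos_to_ignore = {3, 7}  # hypothetical "punctuation-like" tag ids
mask = torch.ones(1, 5, dtype=torch.long)
pos_tags = torch.tensor([[0, 3, 2, 7, 1]])

new_mask = mask.detach()
for label in pos_to_ignore:
    new_mask = new_mask * (1 - pos_tags.eq(label).long())
print(new_mask)  # tensor([[1, 0, 1, 0, 1]]) -- punctuation positions zeroed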
Example 3
    def _rag_sequence_loss(cls, target: torch.LongTensor, scores: torch.Tensor,
                           null_idx: int) -> torch.Tensor:
        """
        RAG Sequence loss.

        :param target:
            target tokens
        :param scores:
            model log probs
        :param null_idx:
            padding index to ignore for loss computation

        :return loss:
            the negative log-likelihood (NLL) loss
        """
        ll = scores.gather(dim=-1, index=target)
        pad_mask = target.eq(null_idx)
        if pad_mask.any():
            ll.masked_fill_(pad_mask, 0.0)
        ll = ll.squeeze(-1)
        ll = ll.sum(2)  # sum over tokens
        ll = ll.logsumexp(1)  # marginalize over docs in log space
        loss = -ll  # negative log-likelihood per batch item

        return loss
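A shape sketch of this loss, under the assumption (consistent with the `sum(2)` and `logsumexp(1)` calls) that `scores` holds log probs of shape `(batch, n_docs, seq_len, vocab)` and `target` is `(batch, n_docs, seq_len, 1)`:

import torch

batch, n_docs, seq_len, vocab = 2, 3, 4, 10  # assumed sizes
scores = torch.randn(batch, n_docs, seq_len, vocab).log_softmax(dim=-1)
target = torch.randint(0, vocab, (batch, n_docs, seq_len, 1))
null_idx = 0

ll = scores.gather(dim=-1, index=target)   # (batch, n_docs, seq_len, 1)
ll.masked_fill_(target.eq(null_idx), 0.0)  # padding contributes log prob 0
ll = ll.squeeze(-1).sum(2)                 # per-doc sequence log-likelihood
loss = -ll.logsumexp(1)                    # (batch,) after doc marginalization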
Example 4
    def _get_unknown_tag_mask(self, mask: torch.LongTensor,
                              head_tags: torch.LongTensor) -> torch.LongTensor:
        oov = self.vocab.get_token_index(DEFAULT_OOV_TOKEN, 'head_tags')
        new_mask = mask.detach()
        oov_mask = head_tags.eq(oov).long()
        new_mask = new_mask * (1 - oov_mask)
        return new_mask
Example 5
def average_pooling(encoded_layers: torch.FloatTensor,
                    token_subword_index: torch.LongTensor) -> torch.Tensor:
    batch_size, num_tokens, num_subwords = token_subword_index.size()
    batch_index = torch.arange(batch_size).view(-1, 1,
                                                1).type_as(token_subword_index)
    token_index = torch.arange(num_tokens).view(1, -1,
                                                1).type_as(token_subword_index)
    _, num_total_subwords, hidden_size = encoded_layers.size()
    expanded_encoded_layers = encoded_layers.unsqueeze(1).expand(
        batch_size, num_tokens, num_total_subwords, hidden_size)
    # [batch_size, num_tokens, num_subwords, hidden_size]
    token_reprs = expanded_encoded_layers[batch_index, token_index,
                                          token_subword_index]
    subword_pad_mask = token_subword_index.eq(0).unsqueeze(3).expand(
        batch_size, num_tokens, num_subwords, hidden_size)
    token_reprs.masked_fill_(subword_pad_mask, 0)
    # [batch_size, num_tokens, hidden_size]
    sum_token_reprs = torch.sum(token_reprs, dim=2)
    # [batch_size, num_tokens]
    num_valid_subwords = token_subword_index.ne(0).sum(dim=2)
    pad_mask = num_valid_subwords.eq(0).long()
    # Add ones to arrays where there is no valid subword.
    divisor = (num_valid_subwords +
               pad_mask).unsqueeze(2).type_as(sum_token_reprs)
    # [batch_size, num_tokens, hidden_size]
    avg_token_reprs = sum_token_reprs / divisor
    return avg_token_reprs
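A usage sketch, assuming the convention that `token_subword_index[b, t]` lists the flat positions of token `t`'s subwords, padded with 0:

import torch

encoded_layers = torch.randn(1, 6, 8)  # (batch, num_total_subwords, hidden)
token_subword_index = torch.tensor([[[1, 2, 0],    # token 0 = subwords 1-2
                                     [3, 0, 0],    # token 1 = subword 3
                                     [0, 0, 0]]])  # token 2 has no subwords
token_reprs = average_pooling(encoded_layers, token_subword_index)
print(token_reprs.shape)  # torch.Size([1, 3, 8])

The all-zero third row exercises the `pad_mask` guard: its divisor is bumped to 1, so the token's representation comes out as zeros instead of NaN.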
Example 6
    def __call__(self,
                 predictions: torch.LongTensor,  # SHAPE: (batch_size, seq_len)
                 labels: torch.LongTensor,  # SHAPE: (batch_size, seq_len)
                 mask: torch.LongTensor  # SHAPE: (batch_size, seq_len)
                 ):
        if len(predictions.size()) != 2:
            raise ValueError('inputs should have two dimensions')
        self._used = True

        # compute three values for each category
        for cate in self._label_cate:
            for label in cate:
                label = self._vocab.get_token_index(label, self._namespace)
                self._count[cate]['gold'] += \
                    (mask * labels.eq(label).long()).sum().item()
                self._count[cate]['predict'] += \
                    (mask * predictions.eq(label).long()).sum().item()
                self._count[cate]['match'] += \
                    (mask * labels.eq(label).long() * labels.eq(predictions).long()).sum().item()
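The three counters map onto the usual per-category metrics; a hedged sketch of how they would typically be consumed (this helper is not part of the original class):

def prf(count):
    # count has keys 'match', 'predict', 'gold', as accumulated above.
    precision = count['match'] / (count['predict'] + 1e-10)
    recall = count['match'] / (count['gold'] + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return precision, recall, f1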
Example 7
    def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        features = self.extract_features(tokens.to(device=self.device))
        sentence_representation = features[
            tokens.eq(self.task.source_dictionary.eos()), :
        ].view(features.size(0), -1, features.size(-1))[:, -1, :]

        logits = self.model.classification_heads[head](sentence_representation)
        if return_logits:
            return logits
        return F.log_softmax(logits, dim=-1)
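The boolean-index-then-view trick above assumes each sequence contains the same number of EOS tokens; a minimal illustration with one EOS per row (values made up):

import torch

eos = 2
tokens = torch.tensor([[5, 7, 2], [9, 2, 8]])
features = torch.randn(2, 3, 4)  # (batch, seq_len, hidden)

# Gather features at EOS positions, then keep the final EOS per sequence.
sentence_repr = features[tokens.eq(eos), :].view(
    features.size(0), -1, features.size(-1))[:, -1, :]
print(sentence_repr.shape)  # torch.Size([2, 4])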
Example 8
    def get_retrieval_indices(
        self, ret_vec: torch.LongTensor, ret_type: RetrievalType
    ) -> torch.LongTensor:
        """
        Return the batch indices for the given retrieval type.

        This function is extremely overloaded to handle all `KnowledgeAccessMethod`s and
        all `RetrievalType`s.

        Basically, if BB2's Access Method is not CLASSIFY, we return all indices
        if the specified retrieval type matches the access method. Otherwise, we look
        at the retrieval_type vector to find the corresponding indices.

        :param ret_vec:
            the retrieval_type vector indicating the "classified" retrieval type
            for each batch item
        :param ret_type:
            the retrieval type being considered here.

        :return indices:
            return which batch indices will utilize the given RetrievalType
        """
        no_indices = torch.zeros(0).long()
        all_indices = torch.arange(ret_vec.size(0)).long()
        type_indices = ret_vec.eq(ret_type.value).nonzero().squeeze(1).long()
        assert isinstance(all_indices, torch.LongTensor)
        assert isinstance(no_indices, torch.LongTensor)
        assert isinstance(type_indices, torch.LongTensor)

        if self.knowledge_access_method is KnowledgeAccessMethod.NONE:
            if ret_type is RetrievalType.NONE:
                return all_indices
            else:
                return no_indices
        elif self.knowledge_access_method is KnowledgeAccessMethod.ALL:
            if ret_type is RetrievalType.NONE:
                return no_indices
            else:
                return all_indices
        elif self.knowledge_access_method is KnowledgeAccessMethod.SEARCH_ONLY:
            if ret_type is RetrievalType.SEARCH:
                return all_indices
            else:
                return no_indices
        elif self.knowledge_access_method is KnowledgeAccessMethod.MEMORY_ONLY:
            if ret_type is RetrievalType.MEMORY:
                return all_indices
            else:
                return no_indices
        else:
            assert self.knowledge_access_method is KnowledgeAccessMethod.CLASSIFY
            return type_indices
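For the CLASSIFY branch, only `type_indices` matters; a small sketch of that selection (the enum value here is made up):

import torch

ret_vec = torch.tensor([0, 1, 1, 2])  # classified retrieval type per item
search_value = 1                      # stand-in for RetrievalType.SEARCH.value
type_indices = ret_vec.eq(search_value).nonzero().squeeze(1)
print(type_indices)  # tensor([1, 2])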
Example 9
    def forward(
        self,
        feature: torch.FloatTensor,  # SHAPE: (num_nodes, feat_size)
        adj: torch.LongTensor,  # SHAPE: (2, num_edges)
        label: torch.LongTensor,  # SHAPE: (num_nodes,)
        train_mask: torch.LongTensor,  # SHAPE: (num_nodes,)
        dev_mask: torch.LongTensor,  # SHAPE: (num_nodes,)
        test_mask: torch.LongTensor):  # SHAPE: (num_nodes,)
        if self.method in ('gcn1', 'gcn1_diag'):
            feature = self.gcn1(feature, adj)
        elif self.method == 'gat1':
            feature = self.gat1(feature, adj)
        pred_repr = self.ff(feature)
        if self.method == 'gcn2':
            pred_repr = self.gcn2(pred_repr, adj)

        # SHAPE: (num_nodes, num_class)
        logits = self.pred_ff(pred_repr)
        loss = nn.CrossEntropyLoss(reduction='none')(logits, label)

        train_mask = train_mask.eq(1)
        dev_mask = dev_mask.eq(1)
        test_mask = test_mask.eq(1)

        train_logits = torch.masked_select(logits,
                                           train_mask.unsqueeze(-1)).view(
                                               -1, self.num_class)
        dev_logits = torch.masked_select(logits, dev_mask.unsqueeze(-1)).view(
            -1, self.num_class)
        test_logits = torch.masked_select(logits,
                                          test_mask.unsqueeze(-1)).view(
                                              -1, self.num_class)

        train_loss = torch.masked_select(loss, train_mask).mean()
        dev_loss = torch.masked_select(loss, dev_mask).mean()
        test_loss = torch.masked_select(loss, test_mask).mean()

        return train_logits, dev_logits, test_logits, train_loss, dev_loss, test_loss
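A minimal illustration (sizes assumed) of the `masked_select`/`view` pattern used for the split logits: the node mask is broadcast across the class dimension, and the flattened selection is reshaped back.

import torch

num_nodes, num_class = 4, 3
logits = torch.randn(num_nodes, num_class)
train_mask = torch.tensor([1, 0, 1, 0]).eq(1)  # boolean node mask

train_logits = torch.masked_select(
    logits, train_mask.unsqueeze(-1)).view(-1, num_class)
print(train_logits.shape)  # torch.Size([2, 3])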
Example 10
    def generate(
        self,
        source: torch.LongTensor,
        limit: int,
        bos_token: int = 1,
        eos_token: int = 2,
    ) -> torch.LongTensor:
        """Generate predicted sequence.

        Args:
            source (torch.LongTensor): source tensor of shape `batch_size, seq_len`.
            limit (int): generation limit.
            bos_token (int): beginning-of-sequence token. Defaults to 1.
            eos_token (int): end-of-sequence token. Defaults to 2.

        Returns:
            torch.LongTensor: generated tensor of shape `batch_size, gen_len`.
        """
        assert limit <= self.max_seq_len

        batch_size, _ = source.shape
        device = source.device

        encoder_emb = self.encoder_embeddings(source)
        pad_mask = source.eq(0)
        encoder_mask = pad_mask.unsqueeze(1).repeat(1, pad_mask.size(-1), 1)
        encoder_repr = self.encoder(encoder_emb, encoder_mask)

        prediction = torch.full(
            (batch_size, 1), bos_token, dtype=torch.long, device=device
        )
        generated = torch.full_like(prediction, 0).bool()
        cache = None

        for i in range(1, limit):
            decoder_emb = self.decoder_embeddings(prediction).narrow(1, -1, 1)
            decoder_repr, cache = self.decoder(decoder_emb, encoder_repr, cache=cache)
            distribution = self.out_to_vocab(decoder_repr).argmax(-1)
            prediction = torch.cat((prediction, distribution), dim=1)

            generated = generated | (distribution == eos_token)
            if torch.all(generated):
                break

        return prediction
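A sketch of the early-stopping bookkeeping in the loop above (the per-step outputs are made up): once every row has emitted EOS at least once, generation stops.

import torch

eos_token = 2
generated = torch.zeros(3, 1, dtype=torch.bool)
steps = [torch.tensor([[5], [2], [7]]),   # item 1 emits EOS first
         torch.tensor([[2], [9], [2]])]   # items 0 and 2 finish next

for distribution in steps:
    generated = generated | (distribution == eos_token)
    if torch.all(generated):
        break
print(generated.squeeze(1))  # tensor([True, True, True])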
Example 11
    def forward(self,
                embedded_tokens: torch.FloatTensor,
                input_mask: torch.LongTensor,
                segment_ids: torch.LongTensor = None):  # pylint: disable=arguments-differ
        embedded_tokens = embedded_tokens * self.embed_scale
        embedded_tokens = common_attention.embedding_postprocessor(
            embedded_tokens,
            input_mask.long(),
            self._use_fp16,
            token_type_ids=segment_ids,
            use_token_type=self._use_token_type,
            token_type_embedding=self._token_type_embedding,
            use_position_embeddings=self._use_position_embeddings,
            position_embedding=self._position_embedding,
            norm_layer=self._norm_layer,
            dropout=self._dropout)
        encoder_self_attention_bias = common_attention.create_attention_mask_from_input_mask(
            embedded_tokens, input_mask, self._use_fp16)

        encoder_padding_mask = input_mask.eq(0)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None
        prev_output = embedded_tokens
        for (attention, feedforward_output, feedforward,
             feedforward_intermediate, layer_norm_output, layer_norm) in zip(
                 self._attention_layers, self._feedforward_output_layers,
                 self._feedforward_layers,
                 self._feedforward_intermediate_layers,
                 self._layer_norm_output_layers, self._layer_norm_layers):
            layer_input = prev_output
            attention_output = attention(layer_input,
                                         encoder_self_attention_bias,
                                         key_padding_mask=encoder_padding_mask)
            attention_output = self._dropout(
                feedforward_output(attention_output))
            attention_output = layer_norm_output(attention_output +
                                                 layer_input)
            attention_intermediate = self._activation(
                feedforward_intermediate(attention_output))
            layer_output = self._dropout(feedforward(attention_intermediate))
            layer_output = layer_norm(layer_output + attention_output)
            prev_output = layer_output

        return prev_output
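The padding-mask convention assumed above: `input_mask` is 1 at real tokens and 0 at padding, so `input_mask.eq(0)` yields the key-padding mask (True where attention keys should be ignored).

import torch

input_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
encoder_padding_mask = input_mask.eq(0)
print(encoder_padding_mask)
# tensor([[False, False,  True],
#         [False, False, False]])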
Example 12
def max_pooling(encoded_layers: torch.FloatTensor,
                token_subword_index: torch.LongTensor) -> torch.Tensor:
    batch_size, num_tokens, num_subwords = token_subword_index.size()
    batch_index = torch.arange(batch_size).view(-1, 1,
                                                1).type_as(token_subword_index)
    token_index = torch.arange(num_tokens).view(1, -1,
                                                1).type_as(token_subword_index)
    _, num_total_subwords, hidden_size = encoded_layers.size()
    expanded_encoded_layers = encoded_layers.unsqueeze(1).expand(
        batch_size, num_tokens, num_total_subwords, hidden_size)
    # [batch_size, num_tokens, num_subwords, hidden_size]
    token_reprs = expanded_encoded_layers[batch_index, token_index,
                                          token_subword_index]
    subword_pad_mask = token_subword_index.eq(0).unsqueeze(3).expand(
        batch_size, num_tokens, num_subwords, hidden_size)
    token_reprs.masked_fill_(subword_pad_mask, -float('inf'))
    # [batch_size, num_tokens, hidden_size]
    max_token_reprs, _ = torch.max(token_reprs, dim=2)
    # [batch_size, num_tokens]
    num_valid_subwords = token_subword_index.ne(0).sum(dim=2)
    pad_mask = num_valid_subwords.eq(0).unsqueeze(2).expand(
        batch_size, num_tokens, hidden_size)
    # Tokens with no valid subwords were left at -inf; reset them to zero.
    # (The original non-in-place masked_fill discarded its result.)
    max_token_reprs.masked_fill_(pad_mask, 0)
    return max_token_reprs
Example 13
def get_segment_mask(class_map: LongTensor, index: int) -> Tensor:
    return class_map.eq(index).float()
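A quick usage check (assuming `from torch import LongTensor, Tensor` is in scope for the snippet above):

import torch

class_map = torch.tensor([[0, 1, 2], [2, 2, 0]])
print(get_segment_mask(class_map, 2))
# tensor([[0., 0., 1.],
#         [1., 1., 0.]])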
Example 14
    def __call__(self,
                 predictions: torch.LongTensor,  # SHAPE: (batch_size, seq_len)
                 labels: torch.LongTensor,  # SHAPE: (batch_size, seq_len)
                 mask: torch.LongTensor,  # SHAPE: (batch_size, seq_len)
                 recall: torch.LongTensor = None,  # SHAPE: (batch_size, seq_len)
                 duplicate_check: bool = True,
                 bucket_value: torch.LongTensor = None,  # SHAPE: (batch_size, seq_len)
                 sig_test: bool = False,
                 mask_oos: torch.LongTensor = None,  # SHAPE: (batch_size, seq_len)
                 ):  # SHAPE: (batch_size)
        if len(predictions.size()) != 2:
            raise ValueError('inputs should have two dimensions')
        self._used = True
        predicted = (predictions.ne(self._neg_label).long() * mask).float()
        if mask_oos is None:
            whole_subset = (labels.ne(self._neg_label).long() * mask).float()
        else:
            whole_subset = (labels.ne(self._neg_label).long() * (mask + mask_oos).ne(0).long()).float()
        self._recall_local += whole_subset.sum().item()
        if recall is not None:
            whole = recall.float()
            if duplicate_check:
                assert whole_subset.sum().item() <= whole.sum().item(), 'found duplicate span pairs'
        else:
            whole = whole_subset

        if self._binary_match:
            matched = (predictions.ne(self._neg_label).long() * labels.ne(self._neg_label).long() * mask).float()
        else:
            matched = (predictions.eq(labels).long() * mask * predictions.ne(self._neg_label).long()).float()

        if self._reduce == 'micro':
            self._count += matched.sum().item()
            self._precision += predicted.sum().item()
            self._recall += whole.sum().item()
            self._num_sample += predictions.size(0)
            if sig_test:
                for i in range(predictions.size(0)):
                    print('ST\t{}\t{}\t{}'.format(int(matched[i].sum().item()),
                                                  int(predicted[i].sum().item()),
                                                  int(whole[i].sum().item())))
            if bucket_value is not None:
                bucket_value = (bucket_value * mask).cpu().numpy().reshape(-1)
                matched = matched.cpu().numpy().reshape(-1)
                predicted = predicted.cpu().numpy().reshape(-1)
                whole_subset = whole_subset.cpu().numpy().reshape(-1)
                count = np.bincount(bucket_value, matched)
                precision = np.bincount(bucket_value, predicted)
                recall = np.bincount(bucket_value, whole_subset)
                for name, value in (('count', count),
                                    ('precision', precision),
                                    ('recall', recall)):
                    for b, v in enumerate(value):
                        self._bucket[name][b] += v

        elif self._reduce == 'macro':
            # TODO: the implementation is problematic because samples without label/prediction
            #   are counted as zero recall/precision
            self._count += matched.size(0)
            pre = matched / (predicted + 1e-10)
            rec = matched / (whole + 1e-10)
            f1 = 2 * pre * rec / (pre + rec + 1e-10)
            self._precision += pre.sum().item()
            self._recall += rec.sum().item()
            self._f1 += f1.sum().item()
            self._num_sample += predictions.size(0)
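The bucketed branch leans on `np.bincount` with weights, which sums the weight vector per bucket id; a small sketch with made-up values:

import numpy as np

bucket_value = np.array([0, 1, 1, 2])     # bucket id per position
matched = np.array([1.0, 0.0, 1.0, 1.0])  # per-position match indicator
print(np.bincount(bucket_value, matched))  # [1. 1. 1.] -- matches per bucket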
Example 15
    def forward(
        self,
        src_tokens: torch.LongTensor,
        src_lengths: torch.LongTensor,
        return_encoder_out: bool = False,
        return_encoder_padding_mask: bool = False,
    ) -> EncoderOuts:
        """Encode a batch of sequences

        Arguments:
            src_tokens {torch.LongTensor} -- [batch_size, seq_len]
            src_lengths {torch.LongTensor} -- [batch_size]

        Keyword Arguments:
            return_encoder_out {bool} --
                Return output tensors? (default: {False})
            return_encoder_padding_mask {bool} --
                Return encoder padding mask? (default: {False})

        Returns:
            EncoderOuts -- the max-pooled sentence embedding, plus the
                encoder outputs and padding mask when requested
        """
        bsz, seqlen = src_tokens.size()

        x = self.embed_tokens(src_tokens)
        x = x.transpose(0, 1)  # BTC -> TBC

        # Pack then apply LSTM
        packed_x = nn.utils.rnn.pack_padded_sequence(x,
                                                     src_lengths,
                                                     batch_first=False,
                                                     enforce_sorted=True)
        packed_outs, (final_hiddens, final_cells) = \
            self.lstm.forward(packed_x)

        x, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=self.padding_value)
        assert list(x.size()) == [seqlen, bsz, self.output_units]

        # Set padded outputs to -inf so they are not selected by max-pooling
        padding_mask = src_tokens.eq(self.padding_idx).t()
        if padding_mask.any():
            x = x.float().masked_fill_(
                mask=padding_mask.unsqueeze(-1),
                value=float('-inf'),
            ).type_as(x)

        # Build the sentence embedding by max-pooling over the encoder outputs
        sentemb = x.max(dim=0)[0]

        encoder_out = None
        if return_encoder_out:
            final_hiddens = self._combine_outs(final_hiddens)
            final_cells = self._combine_outs(final_cells)
            encoder_out = (x, final_hiddens, final_cells)

        encoder_padding_mask = None
        if return_encoder_padding_mask:
            encoder_padding_mask = padding_mask  # same mask as computed above

        return EncoderOuts(sentemb=sentemb,
                           encoder_out=encoder_out,
                           encoder_padding_mask=encoder_padding_mask)
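A minimal illustration (values assumed) of the -inf fill that keeps padded timesteps from winning the max-pool:

import torch

padding_idx = 0
src_tokens = torch.tensor([[5, 7, 0]])  # one padded position
x = torch.randn(3, 1, 4)                # (seq_len, batch, hidden)

padding_mask = src_tokens.eq(padding_idx).t()  # (seq_len, batch)
x = x.masked_fill(padding_mask.unsqueeze(-1), float('-inf'))
sentemb = x.max(dim=0)[0]                # padded steps never win the max
print(sentemb.shape)  # torch.Size([1, 4])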