Example #1
    def extract_features(
        self, tokens: torch.LongTensor, return_all_hiddens: bool = False
    ) -> torch.Tensor:
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        if tokens.size(-1) > min(self.model.max_positions()):
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(
                    tokens.size(-1), self.model.max_positions()
                )
            )
        # move to the model's device (Tensor.to is not in-place, so re-assign)
        tokens = tokens.to(device=self.device)
        prev_output_tokens = tokens.clone()

        # Build decoder inputs by rotating right: position 0 receives the last
        # non-padding token (EOS); the remaining positions are the input shifted by one.
        prev_output_tokens[:, 0] = tokens.gather(
            1,
            (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1),
        ).squeeze()
        prev_output_tokens[:, 1:] = tokens[:, :-1]
        features, extra = self.model(
            src_tokens=tokens,
            src_lengths=None,
            prev_output_tokens=prev_output_tokens,
            features_only=True,
            return_all_hiddens=return_all_hiddens,
        )
        if return_all_hiddens:
            # convert from T x B x C -> B x T x C
            inner_states = extra["inner_states"]
            return [inner_state.transpose(0, 1) for inner_state in inner_states]
        else:
            return features  # just the last layer's features
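A quick way to exercise this method, assuming fairseq is installed and a pretrained BART checkpoint can be fetched through torch.hub (downloading the checkpoint requires network access; the model name below is one of the published fairseq checkpoints, not something defined in this snippet):

import torch

bart = torch.hub.load('pytorch/fairseq', 'bart.large')
bart.eval()
tokens = bart.encode('Hello world!')             # 1-D LongTensor of BPE ids
last_layer = bart.extract_features(tokens)       # [1, seq_len, hidden_size]
all_layers = bart.extract_features(tokens, return_all_hiddens=True)
print(last_layer.shape, len(all_layers))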
Example #2
def average_pooling(encoded_layers: torch.FloatTensor,
                    token_subword_index: torch.LongTensor) -> torch.Tensor:
    """Average the subword vectors of each token into a single token vector.

    encoded_layers: [batch_size, num_total_subwords, hidden_size]
    token_subword_index: [batch_size, num_tokens, num_subwords]; index 0 marks padding.
    """
    batch_size, num_tokens, num_subwords = token_subword_index.size()
    batch_index = torch.arange(batch_size).view(-1, 1,
                                                1).type_as(token_subword_index)
    token_index = torch.arange(num_tokens).view(1, -1,
                                                1).type_as(token_subword_index)
    _, num_total_subwords, hidden_size = encoded_layers.size()
    expanded_encoded_layers = encoded_layers.unsqueeze(1).expand(
        batch_size, num_tokens, num_total_subwords, hidden_size)
    # [batch_size, num_tokens, num_subwords, hidden_size]
    token_reprs = expanded_encoded_layers[batch_index, token_index,
                                          token_subword_index]
    subword_pad_mask = token_subword_index.eq(0).unsqueeze(3).expand(
        batch_size, num_tokens, num_subwords, hidden_size)
    token_reprs.masked_fill_(subword_pad_mask, 0)
    # [batch_size, num_tokens, hidden_size]
    sum_token_reprs = torch.sum(token_reprs, dim=2)
    # [batch_size, num_tokens]
    num_valid_subwords = token_subword_index.ne(0).sum(dim=2)
    pad_mask = num_valid_subwords.eq(0).long()
    # Add ones to arrays where there is no valid subword.
    divisor = (num_valid_subwords +
               pad_mask).unsqueeze(2).type_as(sum_token_reprs)
    # [batch_size, num_tokens, hidden_size]
    avg_token_reprs = sum_token_reprs / divisor
    return avg_token_reprs
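A minimal toy check of the function above (sizes and values invented here for illustration): two sentences, five subword positions, hidden size four, with subword position 0 reserved for padding as the function assumes.

import torch

encoded_layers = torch.randn(2, 5, 4)           # [batch, num_total_subwords, hidden]
token_subword_index = torch.tensor([
    [[1, 2], [3, 0], [4, 0]],                   # sentence 1: 3 tokens
    [[1, 0], [2, 3], [0, 0]],                   # sentence 2: 2 tokens + 1 padded token
])                                              # [batch, num_tokens, num_subwords]
token_reprs = average_pooling(encoded_layers, token_subword_index)
print(token_reprs.shape)                        # torch.Size([2, 3, 4])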
Example #3
    def compute_loss(
        self,
        criterion: torch.nn.Module,
        scores: torch.Tensor,
        preds: torch.LongTensor,
        enc_state: Tuple[Any],
        label_vec: torch.LongTensor,
    ) -> Tuple[torch.Tensor, List[int], torch.Tensor, torch.Tensor]:
        """
        Compute RAG Sequence Loss.

        :param criterion:
            torch criterion module
        :param scores:
            model scores
        :param preds:
            model "predicions" of tokens
        :param enc_state:
            encoder states
        :param label_vec:
            target tokens

        :return (loss, metric_loss, correct_tokens, target_tokens):
            loss: the loss through which we backpropagate
            metric_loss: loss we use for metrics
            correct_tokens: correct predictions from the model
            target_tokens: the ground truth tokens.
        """
        if scores.size(2) != label_vec.size(1):
            assert self.generation_model == 'bart'
            # ignore start
            scores = scores[:, :, 1:, :]
            preds = preds[:, 1:]  # type: ignore

        # compute rag sequence loss
        seq_preds = scores.max(-1)[-1]
        bsz = scores.size(0)
        n_docs = scores.size(1)
        target = label_vec.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
        loss = self._rag_sequence_loss(target, scores,
                                       self.null_idx)  # type: ignore

        # compute relevant metric counters
        metric_loss = loss.tolist()
        notnull = label_vec.ne(self.null_idx)
        target_tokens = notnull.long().sum(dim=-1)

        shape = (bsz, n_docs, -1)
        notnull_seq = target.ne(self.null_idx)
        correct = ((target.view(*shape) == seq_preds.view(*shape)) *
                   notnull_seq.view(*shape)).sum(dim=-1)
        correct_mean = correct.float().mean(dim=1)

        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token
        return loss, metric_loss, correct_mean, target_tokens
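The reshaping step that is easiest to get wrong here is expanding the flat labels so they line up with the per-document scores. A standalone shape check with toy sizes (no model involved; the actual `_rag_sequence_loss` is not reproduced here):

import torch

bsz, n_docs, seq_len, vocab = 2, 3, 5, 11       # toy sizes
scores = torch.randn(bsz, n_docs, seq_len, vocab).log_softmax(dim=-1)
label_vec = torch.randint(1, vocab, (bsz, seq_len))

# replicate labels across the retrieved documents, as in compute_loss above
target = label_vec.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
print(target.shape)                             # torch.Size([2, 3, 5, 1])
# per-token log-probabilities of the gold tokens, one row per document
token_logprobs = scores.gather(dim=-1, index=target).squeeze(-1)
print(token_logprobs.shape)                     # torch.Size([2, 3, 5])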
Example #4
    def compute_loss(
        self,
        criterion: torch.nn.Module,
        scores: torch.Tensor,
        preds: torch.LongTensor,
        enc_state: Tuple[Any],
        label_vec: torch.LongTensor,
    ) -> Tuple[torch.Tensor, List[int], torch.Tensor, torch.Tensor]:
        """
        Compute RAG Token Loss.

        This is a simple NLL Loss.

        :param criterion:
            presumably the NLL criterion.
        :param scores:
            model scores
        :param preds:
            model "predicions" of tokens
        :param enc_state:
            encoder states
        :param label_vec:
            target tokens

        :return (loss, metric_loss, correct_tokens, target_tokens):
            loss: the loss through which we backpropagate
            metric_loss: loss we use for metrics
            correct_tokens: correct predictions from the model
            target_tokens: the ground truth tokens.
        """
        if scores.size(1) != label_vec.size(1):
            assert self.generation_model == 'bart'
            # ignore start
            scores = scores[:, 1:, :]
            preds = preds[:, 1:]  # type: ignore

        # compute loss
        score_view = scores.reshape(-1, scores.size(-1))
        loss = criterion(score_view, label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)

        # calculate metric counters
        metric_loss = loss.tolist()
        notnull = label_vec.ne(self.null_idx)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((label_vec == preds) * notnull).sum(dim=-1)

        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token
        return loss, metric_loss, correct, target_tokens
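The loss arithmetic itself can be checked in isolation. A small sketch assuming a cross-entropy criterion with reduction='none' (the actual criterion is passed in from outside and is not shown above):

import torch

null_idx = 0                                            # assumed padding id
criterion = torch.nn.CrossEntropyLoss(ignore_index=null_idx, reduction='none')

bsz, seq_len, vocab = 2, 4, 7
scores = torch.randn(bsz, seq_len, vocab)               # unnormalized logits
label_vec = torch.tensor([[3, 5, 2, 0], [1, 4, 0, 0]])  # 0 = padding

loss = criterion(scores.reshape(-1, vocab), label_vec.view(-1))
loss = loss.view(bsz, seq_len).sum(dim=1)               # per-example sum of token NLLs
target_tokens = label_vec.ne(null_idx).long().sum(dim=-1)
avg_loss = loss.sum() / target_tokens.sum()             # average loss per token
print(loss.shape, avg_loss.item())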
Example #5
def max_pooling(encoded_layers: torch.FloatTensor,
                token_subword_index: torch.LongTensor) -> torch.Tensor:
    """Max-pool the subword vectors of each token into a single token vector.

    encoded_layers: [batch_size, num_total_subwords, hidden_size]
    token_subword_index: [batch_size, num_tokens, num_subwords]; index 0 marks padding.
    """
    batch_size, num_tokens, num_subwords = token_subword_index.size()
    batch_index = torch.arange(batch_size).view(-1, 1,
                                                1).type_as(token_subword_index)
    token_index = torch.arange(num_tokens).view(1, -1,
                                                1).type_as(token_subword_index)
    _, num_total_subwords, hidden_size = encoded_layers.size()
    expanded_encoded_layers = encoded_layers.unsqueeze(1).expand(
        batch_size, num_tokens, num_total_subwords, hidden_size)
    # [batch_size, num_tokens, num_subwords, hidden_size]
    token_reprs = expanded_encoded_layers[batch_index, token_index,
                                          token_subword_index]
    subword_pad_mask = token_subword_index.eq(0).unsqueeze(3).expand(
        batch_size, num_tokens, num_subwords, hidden_size)
    token_reprs.masked_fill_(subword_pad_mask, -float('inf'))
    # [batch_size, num_tokens, hidden_size]
    max_token_reprs, _ = torch.max(token_reprs, dim=2)
    # [batch_size, num_tokens]
    num_valid_subwords = token_subword_index.ne(0).sum(dim=2)
    pad_mask = num_valid_subwords.eq(0).unsqueeze(2).expand(
        batch_size, num_tokens, hidden_size)
    # use the in-place masked_fill_ so tokens with no valid subword end up as 0, not -inf
    max_token_reprs.masked_fill_(pad_mask, 0)
    return max_token_reprs
Example #6
    def forward(self, input: torch.LongTensor, *args, **kwargs):
        # pass-through: return the token ids unchanged, plus their padding mask
        return input, input.ne(self.pad_idx)
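This fragment is only a method; a minimal module it could live in (names here are illustrative, not taken from the original class) looks like:

import torch
import torch.nn as nn

class PassThroughEncoder(nn.Module):            # illustrative wrapper, not the original class
    def __init__(self, pad_idx: int = 0):
        super().__init__()
        self.pad_idx = pad_idx

    def forward(self, input: torch.LongTensor, *args, **kwargs):
        # return the token ids unchanged together with their padding mask
        return input, input.ne(self.pad_idx)

enc = PassThroughEncoder(pad_idx=0)
tokens = torch.tensor([[5, 7, 0, 0]])
out, mask = enc(tokens)
print(mask)                                     # tensor([[ True,  True, False, False]])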
Example #7
    def compute_loss(
        self,
        criterion: torch.nn.Module,
        scores: torch.Tensor,
        preds: torch.LongTensor,
        enc_state: Tuple[Any, ...],
        label_vec: torch.LongTensor,
    ) -> Tuple[torch.Tensor, List[int], torch.Tensor, torch.Tensor]:
        """
        Compute Loss for Rag Turn.

        RAG Turn Doc-Then-Turn computes loss with a normal NLL Loss
        (everything is marginalized beforehand)

        RAG Turn Doc-Only computes loss for each input turn; this loss can be
        weighted with a discount factor, applying less weight to prior turns (only for
        backpropagation purposes).

        :param criterion:
            torch criterion module
        :param scores:
            model scores
        :param preds:
            model "predicions" of tokens
        :param enc_state:
            encoder states
        :param label_vec:
            target tokens

        :return (loss, metric_loss, correct_tokens, target_tokens):
            loss: the loss through which we backpropagate
            metric_loss: loss we use for metrics
            correct_tokens: correct predictions from the model
            target_tokens: the ground truth tokens.
        """
        if scores.size(1) != label_vec.size(1):
            # ignore start
            scores = scores[:, 1:, :]
            preds = preds[:, 1:]  # type: ignore

        input_turns_cnt = enc_state[2]
        real_bsz = label_vec.size(0)
        resize_label = real_bsz != scores.size(0)
        if resize_label:
            assert self.turn_marginalize == 'doc_only'
            label_vec = label_vec.repeat_interleave(input_turns_cnt,
                                                    dim=0)  # type: ignore

        # compute loss
        score_view = scores.reshape(-1, scores.size(-1))
        loss = criterion(score_view, label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        metric_loss = loss.tolist()

        if resize_label:
            assert self.turn_marginalize == 'doc_only'
            loss = sum_across_turns(loss,
                                    input_turns_cnt,
                                    discount=self.discount_factor)
            metric_loss = sum_across_turns(loss, input_turns_cnt).tolist()

        # compute metric counters
        notnull = label_vec.ne(self.null_idx)
        target_tokens = metric_target_tokens = notnull.long().sum(dim=-1)
        correct = metric_correct = ((label_vec == preds) * notnull).sum(dim=-1)
        if resize_label:
            metric_target_tokens = sum_across_turns(target_tokens,
                                                    input_turns_cnt)
            metric_correct = sum_across_turns(correct, input_turns_cnt)

        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token
        return loss, metric_loss, metric_correct, metric_target_tokens
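In the doc_only branch the labels are duplicated once per input turn before the loss is computed. A toy illustration of that expansion (sizes invented; sum_across_turns itself is not reproduced here):

import torch

input_turns_cnt = torch.tensor([2, 3])          # turns per dialogue in the batch
label_vec = torch.tensor([[4, 5, 0], [6, 7, 8]])

expanded = label_vec.repeat_interleave(input_turns_cnt, dim=0)
print(expanded.shape)                           # torch.Size([5, 3]): row 0 twice, row 1 three times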
Example #8
    def __call__(self,
                 predictions: torch.LongTensor,  # SHAPE: (batch_size, seq_len)
                 labels: torch.LongTensor,  # SHAPE: (batch_size, seq_len)
                 mask: torch.LongTensor,  # SHAPE: (batch_size, seq_len)
                 recall: torch.LongTensor = None,  # SHAPE: (batch_size, seq_len)
                 duplicate_check: bool = True,
                 bucket_value: torch.LongTensor = None,  # SHAPE: (batch_size, seq_len)
                 sig_test: bool = False,
                 mask_oos: torch.LongTensor = None,  # SHAPE: (batch_size, seq_len)
                 ):  # SHAPE: (batch_size)
        if len(predictions.size()) != 2:
            raise Exception('inputs should have two dimensions')
        self._used = True
        predicted = (predictions.ne(self._neg_label).long() * mask).float()
        if mask_oos is None:
            whole_subset = (labels.ne(self._neg_label).long() * mask).float()
        else:
            whole_subset = (labels.ne(self._neg_label).long() * (mask + mask_oos).ne(0).long()).float()
        self._recall_local += whole_subset.sum().item()
        if recall is not None:
            whole = recall.float()
            if duplicate_check:
                assert whole_subset.sum().item() <= whole.sum().item(), 'found duplicate span pairs'
        else:
            whole = whole_subset

        if self._binary_match:
            matched = (predictions.ne(self._neg_label).long() * labels.ne(self._neg_label).long() * mask).float()
        else:
            matched = (predictions.eq(labels).long() * mask * predictions.ne(self._neg_label).long()).float()

        if self._reduce == 'micro':
            self._count += matched.sum().item()
            self._precision += predicted.sum().item()
            self._recall += whole.sum().item()
            self._num_sample += predictions.size(0)
            if sig_test:
                for i in range(predictions.size(0)):
                    print('ST\t{}\t{}\t{}'.format(int(matched[i].sum().item()),
                                                  int(predicted[i].sum().item()),
                                                  int(whole[i].sum().item())))
            if bucket_value is not None:
                bucket_value = (bucket_value * mask).cpu().numpy().reshape(-1)
                matched = matched.cpu().numpy().reshape(-1)
                predicted = predicted.cpu().numpy().reshape(-1)
                whole_subset = whole_subset.cpu().numpy().reshape(-1)
                count = np.bincount(bucket_value, matched)
                precision = np.bincount(bucket_value, predicted)
                recall = np.bincount(bucket_value, whole_subset)
                for name, value in (('count', count), ('precision', precision),
                                    ('recall', recall)):
                    for b, v in enumerate(value):
                        self._bucket[name][b] += v

        elif self._reduce == 'macro':
            # TODO: the implementation is problematic because samples without label/prediction
            #   are counted as zero recall/precision
            self._count += matched.size(0)
            pre = matched / (predicted + 1e-10)
            rec = matched / (whole + 1e-10)
            f1 = 2 * pre * rec / (pre + rec + 1e-10)
            self._precision += pre.sum().item()
            self._recall += rec.sum().item()
            self._f1 += f1.sum().item()
            self._num_sample += predictions.size(0)
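For reference, the micro branch boils down to three masked counts. A stateless one-shot sketch of the same arithmetic (my own helper, not part of the class above, with neg_label = 0 standing in for the negative class):

import torch

def micro_prf(predictions, labels, mask, neg_label=0):
    # spans predicted as something other than the negative class
    predicted = (predictions.ne(neg_label).long() * mask).float().sum()
    # gold spans
    gold = (labels.ne(neg_label).long() * mask).float().sum()
    # non-negative predictions that exactly match the gold label
    matched = (predictions.eq(labels).long() * predictions.ne(neg_label).long() * mask).float().sum()
    precision = (matched / predicted.clamp(min=1)).item()
    recall = (matched / gold.clamp(min=1)).item()
    f1 = 2 * precision * recall / max(precision + recall, 1e-10)
    return precision, recall, f1

preds = torch.tensor([[1, 0, 2], [0, 3, 0]])
labels = torch.tensor([[1, 0, 3], [0, 3, 0]])
mask = torch.ones_like(preds)
print(micro_prf(preds, labels, mask))           # (0.666..., 0.666..., 0.666...)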