def sequence_cross_entropy_with_logits(
        logits: torch.FloatTensor,
        targets: torch.LongTensor,
        mask: torch.BoolTensor,
        label_smoothing: float = 0.0,
        reduce: str = "mean") -> torch.FloatTensor:
    """
    label_smoothing : ``float``, optional (default = 0.0)
        The amount of label smoothing to apply; it should be non-negative and smaller than 1.
    reduce : ``str``, optional (default = "mean")
        One of "mean", "batch", or "batch-sequence".
    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = F.log_softmax(logits_flat, dim=-1)
    # shape : (batch * sequence_length, 1)
    targets_flat = targets.view(-1, 1).long()

    if label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / float(num_classes)
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
            -1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(
            -1, keepdim=True)
    else:
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = -torch.gather(
            log_probs_flat, dim=1, index=targets_flat)

    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(
        -1, logits.shape[1])

    mask = mask.float()
    # shape : (batch, sequence_length)
    loss = negative_log_likelihood * mask

    if reduce == "mean":
        loss = loss.sum() / (mask.sum() + 1e-13)
    elif reduce == "batch":
        # shape : (batch,)
        loss = loss.sum(1) / (mask.sum(1) + 1e-13)
    elif reduce == "batch-sequence":
        # we favor longer sequences, so we don't divide by the total sequence length here
        # shape : (batch,)
        loss = loss.sum(1)

    return loss
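A minimal usage sketch for the function above, with toy tensors and illustrative values (it assumes the usual `import torch` that this snippet already relies on):

import torch

logits = torch.randn(2, 4, 5)                # (batch, sequence_length, num_classes)
targets = torch.randint(0, 5, (2, 4))        # (batch, sequence_length)
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])
loss = sequence_cross_entropy_with_logits(
    logits, targets, mask, label_smoothing=0.1, reduce="mean")
# `loss` is a scalar: the average negative log-likelihood over the 5 unmasked tokens.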
Example #2
    def _joint_likelihood(self, logits: torch.Tensor, tags: torch.Tensor,
                          mask: torch.BoolTensor) -> torch.Tensor:
        """
        Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
        """
        batch_size, sequence_length, _ = logits.data.shape

        # Transpose batch size and sequence dimensions:
        logits = logits.transpose(0, 1).contiguous()
        mask = mask.transpose(0, 1).contiguous()
        tags = tags.transpose(0, 1).contiguous()

        # Start with the transition scores from start_tag to the first tag in each input
        if self.include_start_end_transitions:
            score = self.start_transitions.index_select(0, tags[0])
        else:
            score = 0.0

        # Add up the scores for the observed transitions and all the inputs but the last
        for i in range(sequence_length - 1):
            # Each is shape (batch_size,)
            current_tag, next_tag = tags[i], tags[i + 1]

            # The scores for transitioning from current_tag to next_tag
            transition_score = self.transitions[current_tag.view(-1),
                                                next_tag.view(-1)]

            # The score for using current_tag
            emit_score = logits[i].gather(1, current_tag.view(batch_size,
                                                              1)).squeeze(1)

            # Include the transition score if the next element is unmasked,
            # and the emission score if this element is unmasked.
            score = score + transition_score * mask[i + 1] + emit_score * mask[i]

        # Transition from last state to "stop" state. To start with, we need to find the last tag
        # for each instance.
        last_tag_index = mask.sum(0).long() - 1
        last_tags = tags.gather(0, last_tag_index.view(1,
                                                       batch_size)).squeeze(0)

        # Compute score of transitioning to `stop_tag` from each "last tag".
        if self.include_start_end_transitions:
            last_transition_score = self.end_transitions.index_select(
                0, last_tags)
        else:
            last_transition_score = 0.0

        # Add the last input if it's not masked.
        last_inputs = logits[-1]  # (batch_size, num_tags)
        last_input_score = last_inputs.gather(1, last_tags.view(
            -1, 1))  # (batch_size, 1)
        last_input_score = last_input_score.squeeze()  # (batch_size,)

        score = score + last_transition_score + last_input_score * mask[-1]

        return score
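The per-step transition lookup above relies on advanced indexing into the (num_tags, num_tags) transition matrix; a standalone sketch with toy tensors (the names are illustrative stand-ins, not attributes of the original class):

import torch

transitions = torch.randn(3, 3)              # entry [i, j] scores moving from tag i to tag j
current_tag = torch.tensor([0, 2, 1])        # tags at step i for a batch of 3 sequences
next_tag = torch.tensor([1, 1, 0])           # tags at step i + 1
transition_score = transitions[current_tag, next_tag]   # shape (3,), one score per sequence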
Example #3
    def forward(self, sequence: Tensor, mask: BoolTensor) -> DynamicRnnOutput:
        """
        rnn 执行。特别注意: 所有的都是 batch first
        :param sequence: sequence 序列, shape: (B, seq_len, input_size)
        :param mask: 对 sequence 的 mask, shape: (B, seq_len)
        :return: 解码后的结果,具体参考 DynamicOutput 说明
        """
        assert sequence.dim() == 3, \
            f"sequence shape: {sequence.dim()} 与 (B, seq_len, input_size) 不匹配"

        assert sequence.size(-1) == self.rnn.input_size, \
            f"sequence.size(-1): {sequence.size(-1)} 与 rnn input_size: {self.rnn.input_size} 不相等"

        batch_size = sequence.size(0)
        sequence_length = sequence.size(1)

        sequence_lengths = mask.sum(dim=-1)

        # pack_padded_sequence requires the lengths tensor to live on the CPU
        pack = pack_padded_sequence(sequence,
                                    lengths=sequence_lengths.cpu(),
                                    batch_first=True,
                                    enforce_sorted=False)

        packed_sequence_encoding, last_state = self.rnn(pack)

        encoding, pad_sequence_length = pad_packed_sequence(
            packed_sequence_encoding,
            batch_first=True,
            padding_value=0.0,
            total_length=sequence_length)

        # Only an LSTM returns a (h_n, c_n) tuple; GRU and vanilla RNN return h_n alone
        if self.rnn_type == DynamicRnn.LSTM:
            h_n, c_n = last_state
        else:
            h_n = last_state
            c_n = None

        # h_n shape: (num_layers * num_directions, batch, hidden_size)
        # Since everything here is handled batch first, transpose so that
        # h_n becomes (batch, num_layers, hidden_size * num_directions); c_n gets the same treatment
        h_n = torch.transpose(h_n, 0,
                              1).contiguous().view(batch_size, self.num_layers,
                                                   -1)
        last_layer_h_n = h_n[:, -1, :].contiguous().view(batch_size, -1)

        last_layer_c_n = None
        if c_n is not None:
            c_n = torch.transpose(c_n, 0,
                                  1).contiguous().view(batch_size,
                                                       self.num_layers, -1)
            last_layer_c_n = c_n[:, -1, :].contiguous().view(batch_size, -1)

        return DynamicRnnOutput(last_layer_h_n=last_layer_h_n,
                                last_layer_c_n=last_layer_c_n,
                                h_n=h_n,
                                c_n=c_n,
                                sequence_encoding=encoding)
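The pack/pad round trip used above can be exercised on its own; a minimal sketch with a plain torch.nn.LSTM standing in for self.rnn (note that pack_padded_sequence wants the lengths on the CPU):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

sequence = torch.randn(2, 5, 8)              # (B, seq_len, input_size)
mask = torch.tensor([[True, True, True, True, True],
                     [True, True, True, False, False]])
rnn = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

pack = pack_padded_sequence(sequence, lengths=mask.sum(dim=-1).cpu(),
                            batch_first=True, enforce_sorted=False)
packed_encoding, (h_n, c_n) = rnn(pack)
encoding, _ = pad_packed_sequence(packed_encoding, batch_first=True,
                                  padding_value=0.0, total_length=5)
# encoding: (2, 5, 16); positions beyond each true length are zero padding.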
Example #4
File: base.py Project: ipipan/combo
    def _loss(self, pred: torch.Tensor, true: torch.Tensor,
              mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        BATCH_SIZE, _, CLASSES = pred.size()
        valid_positions = mask.sum()
        pred = pred.reshape(-1, CLASSES)
        true = true.reshape(-1)
        mask = mask.reshape(-1)
        loss = utils.masked_cross_entropy(pred, true, mask)
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions
Example #5
def get_final_encoder_states(
    encoder_outputs: torch.Tensor, mask: torch.BoolTensor, bidirectional: bool = False
) -> torch.Tensor:
    last_word_indices = mask.sum(1) - 1
    batch_size, _, encoder_output_dim = encoder_outputs.size()
    expanded_indices = last_word_indices.view(-1, 1, 1).expand(batch_size, 1, encoder_output_dim)
    final_encoder_output = encoder_outputs.gather(1, expanded_indices)
    final_encoder_output = final_encoder_output.squeeze(1)  # (batch_size, encoder_output_dim)
    if bidirectional:
        final_forward_output = final_encoder_output[:, : (encoder_output_dim // 2)]
        final_backward_output = encoder_outputs[:, 0, (encoder_output_dim // 2) :]
        final_encoder_output = torch.cat([final_forward_output, final_backward_output], dim=-1)
    return final_encoder_output
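A small usage sketch for get_final_encoder_states with toy tensors (the encoder output dimension is assumed even so the bidirectional split is well defined):

import torch

encoder_outputs = torch.randn(2, 5, 6)       # (batch_size, sequence_length, encoder_output_dim)
mask = torch.tensor([[True, True, True, True, True],
                     [True, True, True, False, False]])
final_states = get_final_encoder_states(encoder_outputs, mask, bidirectional=True)
# final_states: (2, 6); for the second row this concatenates the forward state at
# position 2 with the backward state at position 0.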
Example #6
def get_lengths_from_binary_sequence_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    """
    Compute sequence lengths for each batch element in a tensor using a
    binary mask.
    # Parameters
    mask : `torch.BoolTensor`, required.
        A 2D binary mask of shape (batch_size, sequence_length) to
        calculate the per-batch sequence lengths from.
    # Returns
    `torch.LongTensor`
        A torch.LongTensor of shape (batch_size,) representing the lengths
        of the sequences in the batch.
    """
    return mask.sum(-1)
Example #7
def remove_sentence_boundaries(
    tensor: torch.Tensor, mask: torch.BoolTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
    tensor_shape = list(tensor.data.shape)
    new_shape = list(tensor_shape)
    new_shape[1] = tensor_shape[1] - 2
    tensor_without_boundary_tokens = tensor.new_zeros(*new_shape)
    new_mask = tensor.new_zeros((new_shape[0], new_shape[1]), dtype=torch.bool)
    for i, j in enumerate(sequence_lengths):
        if j > 2:
            tensor_without_boundary_tokens[i, : (j - 2), :] = tensor[i, 1 : (j - 1), :]
            new_mask[i, : (j - 2)] = True

    return tensor_without_boundary_tokens, new_mask
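A usage sketch for remove_sentence_boundaries, assuming the input already includes the sentence-boundary positions counted by the mask:

import torch

tensor = torch.randn(2, 6, 4)                # (batch, timesteps, dim), boundaries included
mask = torch.tensor([[True, True, True, True, True, True],
                     [True, True, True, True, False, False]])
trimmed, new_mask = remove_sentence_boundaries(tensor, mask)
# trimmed: (2, 4, 4); row 0 keeps positions 1..4 of the input, row 1 keeps positions 1..2.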
Example #8
    def viterbi_decode(self, h: FloatTensor,
                       mask: BoolTensor) -> List[List[int]]:
        """
        decode labels using viterbi algorithm
        :param h: hidden matrix (batch_size, seq_len, num_labels)
        :param mask: mask tensor of each sequence
                     in mini batch (batch_size, seq_len)
        :return: labels of each sequence in mini batch
        """

        batch_size, seq_len, _ = h.size()
        # prepare the sequence lengths in each sequence
        seq_lens = mask.sum(dim=1)
        # for each sequence in the mini batch, prepare the score
        # from the start state to the first label
        score = [self.start_trans.data + h[:, 0]]
        path = []

        for t in range(1, seq_len):
            # extract the scores of the previous step
            # (batch_size, num_labels, 1)
            previous_score = score[t - 1].view(batch_size, -1, 1)

            # extract the emission scores from the hidden matrix at step t
            # (batch_size, 1, num_labels)
            h_t = h[:, t].view(batch_size, 1, -1)

            # add the transition scores
            # from the label at step t-1 to the label at step t;
            # self.trans_matrix[i, j] holds the score of the transition
            # from label i to label j
            # (batch_size, num_labels, num_labels)
            score_t = previous_score + self.trans_matrix + h_t

            # keep the maximum value and the previous label
            # (backpointer) that attains it for each sequence
            # (batch_size, num_labels)
            best_score, best_path = score_t.max(1)
            score.append(best_score)
            path.append(best_path)

        # predict labels of mini batch
        best_paths = [
            self._viterbi_compute_best_path(i, seq_lens, score, path)
            for i in range(batch_size)
        ]

        return best_paths
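One step of the recurrence above, spelled out with toy tensors (3 labels, batch of 2; the names mirror the method but are stand-ins, not attributes of the original class):

import torch

previous_score = torch.randn(2, 3).view(2, 3, 1)   # score of each previous label
trans_matrix = torch.randn(3, 3)                    # transition score from label i to label j
h_t = torch.randn(2, 3).view(2, 1, 3)               # emission scores at step t
score_t = previous_score + trans_matrix + h_t       # (2, 3, 3): [batch, prev_label, curr_label]
best_score, best_path = score_t.max(1)              # best previous label for every current label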
Example #9
    def __call__(  # type: ignore
            self, predictions: Dict[str, torch.Tensor],
            gold_labels: Dict[str, torch.Tensor], mask: torch.BoolTensor):
        self.upos_score(predictions["upostag"], gold_labels["upostag"], mask)
        self.xpos_score(predictions["xpostag"], gold_labels["xpostag"], mask)
        self.semrel_score(predictions["semrel"], gold_labels["semrel"], mask)
        self.feats_score(predictions["feats"], gold_labels["feats"], mask)
        self.lemma_score(predictions["lemma"], gold_labels["lemma"], mask)
        self.attachment_scores(predictions["head"], predictions["deprel"],
                               gold_labels["head"], gold_labels["deprel"],
                               mask)
        total = mask.sum()
        correct_indices = (self.upos_score.correct_indices *
                           self.xpos_score.correct_indices *
                           self.semrel_score.correct_indices *
                           self.feats_score.correct_indices *
                           self.lemma_score.correct_indices *
                           self.attachment_scores.correct_indices)

        total, correct_indices = self.detach_tensors(total, correct_indices)
        self.em_score = (correct_indices.float().sum() / total).item()
Example #10
    def _compute_numerator_log_likelihood(self, h: FloatTensor, y: LongTensor,
                                          mask: BoolTensor) -> FloatTensor:
        """
        compute the numerator term for the log-likelihood
        :param h: hidden matrix (batch_size, seq_len, num_labels)
        :param y: answer labels of each sequence
                  in mini batch (batch_size, seq_len)
        :param mask: mask tensor of each sequence
                     in mini batch (batch_size, seq_len)
        :return: The score of numerator term for the log-likelihood
        """

        batch_size, seq_len, _ = h.size()

        h_unsqueezed = h.unsqueeze(-1)
        trans = self.trans_matrix.unsqueeze(-1)

        arange_b = torch.arange(batch_size)

        # extract first vector of sequences in mini batch
        calc_range = seq_len - 1
        score = self.start_trans[y[:, 0]] + sum([
            self._calc_trans_score_for_num_llh(h_unsqueezed, y, trans, mask, t,
                                               arange_b)
            for t in range(calc_range)
        ])

        # extract end label number of each sequence in mini batch
        # (batch_size)
        last_mask_index = mask.sum(1) - 1
        last_labels = y[arange_b, last_mask_index]
        each_last_score = h[arange_b, -1, last_labels] * mask[:, -1]

        # Add the score of the sequences of the maximum length in mini batch
        # Add the scores from the last tag of each sequence to EOS
        score += each_last_score + self.end_trans[last_labels]

        return score
Example #11
def compute_log_probability(
    logits: torch.FloatTensor,
    targets: torch.LongTensor,
    mask: torch.BoolTensor = None,
    debug_fxn: Callable[[object, str], None] = null_log,
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """
    Compute sum of log probs from model logits

    Arguments:
        logits (torch.FloatTensor): Model output logits (B x T x V)
        targets (torch.LongTensor): Target tokens (B x T)
        mask (torch.BoolTensor): Mask revealing only the utterance tokens (B x T)
        debug_fxn (callable): Logging function

    Returns:
        torch.FloatTensor: Target log probabilities (B x T)
        torch.LongTensor: Number of utterance tokens (1)
    """
    # Get log probability from logits via log softmax
    log_probs = torch.log_softmax(logits, dim=-1)
    debug_fxn(log_probs, 'log_probs')
    debug_fxn(targets, 'targets')

    # Extract target token probability - (B x T)
    target_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    debug_fxn(target_log_probs, 'target_log_probs')

    # Mask to utterance tokens
    if mask is not None:
        target_log_probs = target_log_probs.masked_select(mask)
        debug_fxn(target_log_probs, 'target_log_probs (masked)')
        n_tokens = mask.sum()
    else:
        n_tokens = target_log_probs.numel()
    debug_fxn(n_tokens, 'n_tokens')

    return target_log_probs, n_tokens
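A usage sketch, assuming the surrounding module (including the `null_log` no-op logger used as the default `debug_fxn`) is importable:

import torch

logits = torch.randn(2, 4, 10)               # (B, T, V)
targets = torch.randint(0, 10, (2, 4))       # (B, T)
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])
target_log_probs, n_tokens = compute_log_probability(logits, targets, mask)
# target_log_probs: (5,) log-probabilities of the unmasked target tokens; n_tokens == 5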
Example #12
    def forward(
        self, inputs: torch.Tensor, mask: torch.BoolTensor
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        assert len(inputs.shape) == 3
        assert len(mask.shape) == 2
        assert inputs.shape[:-1] == mask.shape

        _, seq_len, _ = inputs.shape

        sequence_lengths = mask.sum(-1)

        # pack_padded_sequence expects the lengths tensor on the CPU
        packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs,
                                                          sequence_lengths.cpu(),
                                                          batch_first=True,
                                                          enforce_sorted=False)

        packed_outputs, (h_n, c_n) = super(LstmWrapper,
                                           self).forward(packed_inputs)

        output, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs,
                                                     batch_first=True,
                                                     total_length=seq_len)

        return output, (h_n, c_n)
Example #13
    def _loss(
            self, pred: torch.Tensor, true: torch.Tensor,
            mask: torch.BoolTensor,
            sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        BATCH_SIZE, N, M = pred.size()
        assert N == M
        SENTENCE_LENGTH = N

        valid_positions = mask.sum()

        result = []
        # Ignore first pred dimension as it is ROOT token prediction
        for i in range(SENTENCE_LENGTH - 1):
            pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH)
            true_i = true[:, i].reshape(-1)
            mask_i = mask[:, i]
            cross_entropy_loss = utils.masked_cross_entropy(
                pred_i, true_i, mask_i)
            result.append(cross_entropy_loss)
        cycle_loss = self._cycle_loss(pred)
        loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
Example #14
    def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        assert pred.size() == true.size()
        BATCH_SIZE, _, MORPHOLOGICAL_FEATURES = pred.size()

        valid_positions = mask.sum()

        pred = pred.reshape(-1, MORPHOLOGICAL_FEATURES)
        true = true.reshape(-1, MORPHOLOGICAL_FEATURES)
        mask = mask.reshape(-1)
        loss = None
        loss_func = utils.masked_cross_entropy
        for cat, cat_indices in self.slices.items():
            if cat not in ["__PAD__", "_"]:
                if loss is None:
                    loss = loss_func(pred[:, cat_indices],
                                     true[:, cat_indices].argmax(dim=1),
                                     mask)
                else:
                    loss += loss_func(pred[:, cat_indices],
                                      true[:, cat_indices].argmax(dim=1),
                                      mask)
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions
Example #15
def add_sentence_boundary_token_ids(
    tensor: torch.Tensor, mask: torch.BoolTensor, sentence_begin_token: Any, sentence_end_token: Any
) -> Tuple[torch.Tensor, torch.BoolTensor]:
    sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
    tensor_shape = list(tensor.data.shape)
    new_shape = list(tensor_shape)
    new_shape[1] = tensor_shape[1] + 2
    tensor_with_boundary_tokens = tensor.new_zeros(*new_shape)
    if len(tensor_shape) == 2:
        tensor_with_boundary_tokens[:, 1:-1] = tensor
        tensor_with_boundary_tokens[:, 0] = sentence_begin_token
        for i, j in enumerate(sequence_lengths):
            tensor_with_boundary_tokens[i, j + 1] = sentence_end_token
        new_mask = tensor_with_boundary_tokens != 0
    elif len(tensor_shape) == 3:
        tensor_with_boundary_tokens[:, 1:-1, :] = tensor
        for i, j in enumerate(sequence_lengths):
            tensor_with_boundary_tokens[i, 0, :] = sentence_begin_token
            tensor_with_boundary_tokens[i, j + 1, :] = sentence_end_token
        new_mask = (tensor_with_boundary_tokens > 0).sum(dim=-1) > 0
    else:
        raise ValueError("add_sentence_boundary_token_ids only accepts 2D and 3D input")

    return tensor_with_boundary_tokens, new_mask
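A 2D usage sketch for add_sentence_boundary_token_ids, with hypothetical begin/end token ids 1 and 2 and 0 as padding:

import torch

token_ids = torch.tensor([[7, 8, 9, 10],
                          [7, 8, 0, 0]])
mask = token_ids != 0
with_bounds, new_mask = add_sentence_boundary_token_ids(token_ids, mask, 1, 2)
# with_bounds: [[1, 7, 8, 9, 10, 2],
#               [1, 7, 8, 2, 0, 0]]; new_mask is True wherever with_bounds is non-zero.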
Example #16
    def _unfold_long_sequences(
        self,
        embeddings: torch.FloatTensor,
        mask: torch.BoolTensor,
        batch_size: int,
        num_segment_concat_wordpieces: int,
    ) -> torch.FloatTensor:
        """
        We take the 2D segments of a long sequence and flatten them out to get the whole sequence
        representation while removing the unnecessary special tokens.

        [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
        -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

        We truncate the start and end tokens for all segments, recombine the segments,
        and manually add back the start and end tokens.

        # Parameters

        embeddings: `torch.FloatTensor`
            Shape: [batch_size * num_segments, self._max_length, embedding_size].
        mask: `torch.BoolTensor`
            Shape: [batch_size * num_segments, self._max_length].
            The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
            in `forward()`.
        batch_size: `int`
        num_segment_concat_wordpieces: `int`
            The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
            the original `token_ids.size(1)`.

        # Returns:

        embeddings: `torch.FloatTensor`
            Shape: [batch_size, self._num_wordpieces, embedding_size].
        """
        def lengths_to_mask(lengths, max_len, device):
            return torch.arange(max_len, device=device).expand(
                lengths.size(0), max_len) < lengths.unsqueeze(1)

        device = embeddings.device
        num_segments = int(embeddings.size(0) / batch_size)
        embedding_size = embeddings.size(2)

        # We want to remove all segment-level special tokens but maintain sequence-level ones
        num_wordpieces = num_segment_concat_wordpieces - (
            num_segments - 1) * self._num_added_tokens

        embeddings = embeddings.reshape(batch_size,
                                        num_segments * self._max_length,
                                        embedding_size)
        mask = mask.reshape(batch_size, num_segments * self._max_length)
        # We assume that all 1s in the mask precede all 0s, and add an assert for that.
        # Open an issue on GitHub if this breaks for you.
        # Shape: (batch_size,)
        seq_lengths = mask.sum(-1)
        if not (lengths_to_mask(seq_lengths, mask.size(1), device)
                == mask).all():
            raise ValueError(
                "Long sequence splitting only supports masks with all 1s preceding all 0s."
            )
        # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
        end_token_indices = (
            seq_lengths.unsqueeze(-1) -
            torch.arange(self._num_added_end_tokens, device=device) - 1)

        # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
        start_token_embeddings = embeddings[:, :self._num_added_start_tokens, :]
        # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
        end_token_embeddings = batched_index_select(embeddings,
                                                    end_token_indices)

        embeddings = embeddings.reshape(batch_size, num_segments,
                                        self._max_length, embedding_size)
        # truncate segment-level start/end tokens
        embeddings = embeddings[
            :, :, self._num_added_start_tokens:-self._num_added_end_tokens, :]
        # flatten
        embeddings = embeddings.reshape(batch_size, -1, embedding_size)

        # Now try to put end token embeddings back which is a little tricky.

        # The number of segments each sequence spans, excluding padding. Mimicking a ceiling
        # operation via integer division so the result can be used as an index offset below.
        # Shape: (batch_size,)
        num_effective_segments = (seq_lengths + self._max_length -
                                  1) // self._max_length
        # The number of indices that end tokens should shift back.
        num_removed_non_end_tokens = (
            num_effective_segments * self._num_added_tokens -
            self._num_added_end_tokens)
        # Shape: (batch_size, self._num_added_end_tokens)
        end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
        assert (end_token_indices >= self._num_added_start_tokens).all()
        # Add space for end embeddings
        embeddings = torch.cat(
            [embeddings, torch.zeros_like(end_token_embeddings)], 1)
        # Add end token embeddings back
        embeddings.scatter_(
            1,
            end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings),
            end_token_embeddings)

        # Now put back start tokens. We can do this before putting back end tokens, but then
        # we need to change `num_removed_non_end_tokens` a little.
        embeddings = torch.cat([start_token_embeddings, embeddings], 1)

        # Truncate to original length
        embeddings = embeddings[:, :num_wordpieces, :]
        return embeddings
Example #17
    def _construct_loss(
        self,
        head_tag: torch.Tensor,
        child_tag: torch.Tensor,
        score_arc: torch.Tensor,
        head_indices: torch.Tensor,
        head_tags: torch.Tensor,
        mask: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        # Parameters

        head_tag : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        score_arc : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to
            generate a distribution over attachments of a given word to all other words.
        head_indices : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : `torch.BoolTensor`, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        # Returns

        arc_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc loss.
        tag_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc tag loss.
        """
        batch_size, sequence_length, _ = score_arc.size()
        # shape (batch_size, 1)
        range_vector = torch.arange(batch_size, device=score_arc.device).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = (
            masked_log_softmax(score_arc, mask) * mask.unsqueeze(2) * mask.unsqueeze(1)
        )

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag, child_tag, head_indices)
        normalised_head_tag_logits = masked_log_softmax(
            head_tag_logits, mask.unsqueeze(-1)
        ) * mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = torch.arange(sequence_length, device=score_arc.device)
        child_index = (
            timestep_index.view(1, sequence_length)
            .expand(batch_size, sequence_length)
            .long()
        )
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
Example #18
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        parent_mask: torch.BoolTensor,
        parent_start_mask: torch.BoolTensor,
        parent_end_mask: torch.BoolTensor,
        child_mask: torch.BoolTensor = None,
        parent_idxs: torch.LongTensor = None,
        parent_tags: torch.LongTensor = None,
        parent_starts: torch.BoolTensor = None,
        parent_ends: torch.BoolTensor = None,
        child_idxs: torch.BoolTensor = None,
        child_starts: torch.BoolTensor = None,
        child_ends: torch.BoolTensor = None,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            parent_mask: [batch_size, num_words]
            parent_start_mask: [batch_size, num_words]
            parent_end_mask: [batch_size, num_words]
            child_mask: [batch_size, num_words]
            parent_idxs: [batch_size]
            parent_tags: [batch_size]
            parent_starts: [batch_size]
            parent_ends: [batch_size]
            child_idxs: [batch_size, num_words]
            child_starts: [batch_size, num_words]
            child_ends: [batch_size, num_words]
        Returns:
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words, num_tags]
            parent_start_probs: [batch_size, num_words]
            parent_end_probs: [batch_size, num_words]
            child_probs: [batch_size, num_words]
            child_start_probs: [batch_size, num_words]
            child_end_probs: [batch_size, num_words]
            arc_loss (if parent_idx is not None)
            tag_loss (if parent_idxs and parent_tags are not None)
            start_loss (if parent_starts is not None)
            end_loss (if parent_ends is not None)
            child_loss (if child_idxs is not None)
            child_start_loss (if child_starts is not None)
            child_end_loss (if child_ends is not None)
        """

        cls_embedding, embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                # bert = self.bert if self.arch == "bert" else self.roberta
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)
        parent_start_scores = self.parent_start_feedforward(
            encoded_text).squeeze(-1)
        parent_end_scores = self.parent_end_feedforward(encoded_text).squeeze(
            -1)

        # mask out impossible positions
        minus_inf = -1e8
        parent_mask = torch.logical_and(parent_mask, word_mask)
        parent_scores = parent_scores + (~parent_mask).float() * minus_inf
        parent_start_mask = torch.logical_and(parent_start_mask, word_mask)
        parent_start_scores = parent_start_scores + (
            ~parent_start_mask).float() * minus_inf
        parent_end_mask = torch.logical_and(parent_end_mask, word_mask)
        parent_end_scores = parent_end_scores + (
            ~parent_end_mask).float() * minus_inf

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_start_probs = F.softmax(parent_start_scores, dim=-1)
        parent_end_probs = F.softmax(parent_end_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

        output = (parent_probs, parent_tag_probs, parent_start_probs,
                  parent_end_probs)

        if self.config.predict_child:
            child_scores = self.child_feedforward(encoded_text).squeeze(-1)
            child_start_scores = self.child_start_feedforward(
                encoded_text).squeeze(-1)
            child_end_scores = self.child_end_feedforward(
                encoded_text).squeeze(-1)
            # todo add child mask - child should be inside the origin span
            if child_mask is None:
                child_mask = torch.ones_like(word_mask)
            else:
                child_mask = torch.logical_and(child_mask, word_mask)
            child_scores = child_scores + (~child_mask).float() * minus_inf
            child_start_scores = child_start_scores + (
                ~child_mask).float() * minus_inf
            child_end_scores = child_end_scores + (
                ~child_mask).float() * minus_inf
            child_probs = torch.sigmoid(child_scores)
            child_start_probs = torch.sigmoid(child_start_scores)
            child_end_probs = torch.sigmoid(child_end_scores)
            output = output + (child_probs, child_start_probs, child_end_probs)

        # add losses
        batch_range_vector = get_range_vector(
            batch_size, get_device_of(encoded_text))  # [bsz]
        if parent_idxs is not None:
            # [bsz, seq_len]
            parent_logits = F.log_softmax(parent_scores, dim=-1)
            parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
            parent_arc_nll = parent_arc_nll.mean()
            output = output + (parent_arc_nll, )

            if parent_tags is not None:
                parent_tag_nll = F.cross_entropy(
                    parent_tag_scores[batch_range_vector, parent_idxs],
                    parent_tags)
                output = output + (parent_tag_nll, )

        if parent_starts is not None:
            # [bsz, seq_len]
            parent_start_logits = F.log_softmax(parent_start_scores, dim=-1)
            parent_start_nll = -parent_start_logits[batch_range_vector,
                                                    parent_starts].mean()
            output = output + (parent_start_nll, )

        if parent_ends is not None:
            # [bsz, seq_len]
            parent_end_logits = F.log_softmax(parent_end_scores, dim=-1)
            parent_end_nll = -parent_end_logits[batch_range_vector,
                                                parent_ends].mean()
            output = output + (parent_end_nll, )

        if self.config.predict_child:
            if child_idxs is not None:
                child_loss = F.binary_cross_entropy_with_logits(
                    child_scores, child_idxs.float(), reduction="none")
                child_loss = (child_loss *
                              child_mask).sum() / (child_mask.sum() + 1e-8)
                output = output + (child_loss, )
            if child_starts is not None:
                child_start_loss = F.binary_cross_entropy_with_logits(
                    child_start_scores, child_starts.float(), reduction="none")
                child_start_loss = (child_start_loss * child_mask).sum() / (
                    child_mask.sum() + 1e-8)
                output = output + (child_start_loss, )
            if child_ends is not None:
                child_end_loss = F.binary_cross_entropy_with_logits(
                    child_end_scores, child_ends.float(), reduction="none")
                child_end_loss = (child_end_loss *
                                  child_mask).sum() / (child_mask.sum() + 1e-8)
                output = output + (child_end_loss, )

        return output
Example #19
    def forward(self, tokens: torch.Tensor, mask: torch.BoolTensor):
        if mask is not None:
            tokens = tokens * mask.unsqueeze(-1)
        else:
            # If mask doesn't exist create one of shape (batch_size, num_tokens)
            mask = torch.ones(tokens.shape[0],
                              tokens.shape[1],
                              device=tokens.device).bool()

        # Our input is expected to have shape `(batch_size, num_tokens, embedding_dim)`.  The
        # convolution layers expect input of shape `(batch_size, in_channels, sequence_length)`,
        # where the conv layer `in_channels` is our `embedding_dim`.  We thus need to transpose the
        # tensor first.
        tokens = torch.transpose(tokens, 1, 2)
        # Each convolution layer returns output of size `(batch_size, num_filters, pool_length)`,
        # where `pool_length = num_tokens - ngram_size + 1`.  We then do an activation function,
        # masking, then do max pooling over each filter for the whole input sequence.
        # Because our max pooling is simple, we just use `torch.max`.  The resultant tensor has shape
        # `(batch_size, num_conv_layers * num_filters)`, which then gets projected using the
        # projection layer, if requested.

        # To ensure the cnn_encoder respects masking we add a large negative value to
        # the activations of all filters that convolved over a masked token. We do this by
        # first enumerating all filters for a given convolution size (torch.arange())
        # then by comparing it to an index of the last filter that does not involve a masked
        # token (.ge()) and finally adjusting dimensions to allow for addition and multiplying
        # by a large negative value (.unsqueeze())
        filter_outputs = []
        batch_size = tokens.shape[0]
        # shape: (batch_size, 1)
        last_unmasked_tokens = mask.sum(dim=1).unsqueeze(dim=-1)
        for i in range(len(self._convolution_layers)):
            convolution_layer = getattr(self, "conv_layer_{}".format(i))
            pool_length = tokens.shape[2] - convolution_layer.kernel_size[0] + 1

            # Forward pass of the convolutions.
            # shape: (batch_size, num_filters, pool_length)
            activations = self._activation(convolution_layer(tokens))

            # Create activation mask.
            # shape: (batch_size, pool_length)
            indices = (torch.arange(
                pool_length, device=activations.device).unsqueeze(0).expand(
                    batch_size, pool_length))
            # shape: (batch_size, pool_length)
            activations_mask = indices.ge(last_unmasked_tokens -
                                          convolution_layer.kernel_size[0] + 1)
            # shape: (batch_size, num_filters, pool_length)
            activations_mask = activations_mask.unsqueeze(1).expand_as(
                activations)

            # Replace masked out values with smallest possible value of the dtype so
            # that max pooling will ignore these activations.
            # shape: (batch_size, num_filters, pool_length)
            activations = activations + (activations_mask *
                                         min_value_of_dtype(activations.dtype))

            # Pick out the max filters
            filter_outputs.append(activations.max(dim=2)[0])

        # Now we have a list of `num_conv_layers` tensors of shape `(batch_size, num_filters)`.
        # Concatenating them gives us a tensor of shape `(batch_size, num_filters * num_conv_layers)`.
        maxpool_output = (torch.cat(filter_outputs, dim=1)
                          if len(filter_outputs) > 1 else filter_outputs[0])

        # Replace the maxpool activations that picked up the masks with 0s
        maxpool_output[maxpool_output == min_value_of_dtype(
            maxpool_output.dtype)] = 0.0

        if self.projection_layer:
            result = self.projection_layer(maxpool_output)
        else:
            result = maxpool_output
        return result
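The masked max-pooling trick above can be seen in isolation; a sketch that uses torch.finfo(...).min in place of AllenNLP's min_value_of_dtype helper:

import torch

activations = torch.randn(2, 3, 5)           # (batch_size, num_filters, pool_length)
activations_mask = torch.tensor([[False, False, False, False, False],
                                 [False, False, False, True, True]])
activations_mask = activations_mask.unsqueeze(1).expand_as(activations)
min_value = torch.finfo(activations.dtype).min
masked = activations + activations_mask * min_value   # masked filter positions become hugely negative
pooled = masked.max(dim=2)[0]                          # (batch_size, num_filters)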
Example #20
def get_lengths_from_binary_sequence_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    return mask.sum(-1)