Example #1
    def _compute_loss(self, target_tokens: Dict[str, torch.Tensor],
                      target_ids: torch.Tensor, salience_values: torch.Tensor,
                      predicted_salience: torch.Tensor,
                      state: Dict[str, torch.Tensor]):
        # (B, L, V)
        all_class_probs = state['all_class_probs']
        attentions = state['all_attentions']
        source_mask = state['source_mask']
        target_mask = util.get_text_field_mask(target_tokens)[:, :-1]
        target = target_ids[:, 1:]
        assert target_mask.size(1) == target.size(1)
        # L = number of decoder time steps
        length = all_class_probs.size(1)
        step_losses = all_class_probs.new_zeros((all_class_probs.size(0), ))
        target_mask_t = target_mask.transpose(0, 1).contiguous()
        coverages = state['all_coverages']
        coverage_losses = all_class_probs.new_zeros(
            (all_class_probs.size(0), ))

        for i in range(length):
            gold_probs = torch.gather(all_class_probs[:, i, :], 1,
                                      target[:, i].unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + 1e-7)
            if self.is_coverage:
                step_coverage_loss = torch.sum(
                    torch.min(attentions[:, :, i], coverages[:, :, i]), 1)
                step_loss = step_loss + self.coverage_lambda * step_coverage_loss
                step_coverage_loss = step_coverage_loss * target_mask_t[i]
                coverage_losses += step_coverage_loss
            step_loss = step_loss * target_mask_t[i]
            step_losses += step_loss

        if self.is_coverage:
            batch_coverage_loss = coverage_losses / \
                                  util.get_lengths_from_binary_sequence_mask(target_mask)
            total_coverage_loss = torch.mean(batch_coverage_loss)
            self.coverage_loss = total_coverage_loss.item()
        batch_avg_loss = step_losses / util.get_lengths_from_binary_sequence_mask(
            target_mask)
        total_loss = torch.mean(batch_avg_loss)
        if predicted_salience is not None:
            predicted_salience = source_mask * predicted_salience.squeeze(2)
            salience_values = source_mask * salience_values
            salience_loss = self.prediction_criterion(predicted_salience,
                                                      salience_values)
            total_loss = total_loss + self.salience_lambda * salience_loss
            self.salience_MSE = salience_loss.item()
        return total_loss
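The per-step loop above can be collapsed into a single gather over the vocabulary axis. A minimal vectorized sketch of the NLL term only (coverage and salience omitted), assuming the (B, L, V) probabilities and (B, L) target/mask shapes used in the method:

import torch

def vectorized_nll(all_class_probs, target, target_mask):
    # all_class_probs: (B, L, V) probabilities; target: (B, L) gold ids; target_mask: (B, L) 0/1
    gold_probs = torch.gather(all_class_probs, 2, target.unsqueeze(2)).squeeze(2)  # (B, L)
    step_losses = -torch.log(gold_probs + 1e-7) * target_mask.float()
    batch_avg_loss = step_losses.sum(1) / target_mask.float().sum(1)  # (B,)
    return batch_avg_loss.mean()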
Example #2
    def calculate_loss(self, pred, target, mask):
        # pred: shape:[batchsize, seq_len, seq_len, 49]
        assert len(pred.shape) == 4 and len(target.shape) == 4
        assert pred.shape == target.shape
        assert mask.shape == pred.shape[:2]
        batchsize, seq_len, _, r_num = pred.shape
        # shape: (batchsize, )
        mask_length = get_lengths_from_binary_sequence_mask(mask)
        # Build a new mask matrix covering only the valid (unpadded) region
        new_mask = torch.zeros(
            (batchsize, seq_len, seq_len, r_num)).to(mask.device)
        for i, v in enumerate(mask_length):
            v = v.item()
            new_mask[i, :v, :v, :] = 1
        count_all = mask_length.sum()
        youxiao = target.sum()  # number of positive cells ("youxiao" = valid)
        assert count_all.item() >= youxiao.item()
        # Re-weight within the batch: upweight the sparse positive cells to
        # balance positive and negative samples.

        new_mask[target == 1] = (count_all - youxiao) / (youxiao + 1e-20)
        return binary_cross_entropy_with_logits(pred,
                                                target,
                                                new_mask,
                                                reduction='sum')
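The re-weighting line gives every positive cell a weight equal to the negative-to-positive ratio, so sparse positive labels are not drowned out by the 'sum' reduction. A toy sketch of just that computation (shapes here are illustrative):

import torch

target = torch.zeros(2, 4, 4, 3)
target[0, 1, 2, 0] = 1.0                    # a single positive cell
mask_length = torch.tensor([3, 2])          # valid sequence lengths per instance
count_all = mask_length.sum().float()
num_positive = target.sum()
pos_weight = (count_all - num_positive) / (num_positive + 1e-20)
print(pos_weight)                           # tensor(4.) - the weight each positive cell gets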
Example #3
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        ph: do NOT perform constrained Viterbi decoding - with all-zero transition scores the decode below reduces to per-step argmax; we are interested in learning dynamics, not best performance
        """
        all_predictions = output_dict['class_probabilities']
        sequence_lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [all_predictions[i].detach().cpu() for i in range(all_predictions.size(0))]
        else:
            predictions_list = [all_predictions]
        all_tags = []

        # ph: transition matrices contain only zeros (and no -inf, which would signal an illegal transition)
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])
        start_transitions = torch.zeros(num_labels)

        for predictions, length in zip(predictions_list, sequence_lengths):
            max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix,
                                                        allowed_start_transitions=start_transitions)
            tags = [self.vocab.get_token_from_index(x, namespace="labels")
                    for x in max_likelihood_sequence]
            all_tags.append(tags)
        output_dict['tags'] = all_tags
        return output_dict
Example #4
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
        decoding produces low quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible as it is trained
        to perform tagging on wordpieces, not words).

        Secondly, it's important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces), as otherwise we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens in the case that a word
        is split into multiple wordpieces, and then we take the last tag of the word, which might
        correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # **************** Different ********************
        # We add in the offsets here so we can compute the un-wordpieced tags.
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(
                    x, namespace=self._label_namespace)
                for x in max_likelihood_sequence
            ]

            wordpiece_tags.append(tags)
            if isinstance(self.bert_model,
                          PretrainedTransformerMismatchedEmbedder):
                word_tags.append(tags)
            else:
                word_tags.append([tags[i] for i in offsets])
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict
Example #5
    def test_forward_pulls_out_correct_tensor_for_unsorted_batches(self):
        lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2SeqWrapper(lstm)
        tensor = torch.rand([5, 7, 3])
        tensor[0, 3:, :] = 0
        tensor[1, 4:, :] = 0
        tensor[2, 2:, :] = 0
        tensor[3, 6:, :] = 0
        mask = torch.ones(5, 7)
        mask[0, 3:] = 0
        mask[1, 4:] = 0
        mask[2, 2:] = 0
        mask[3, 6:] = 0

        input_tensor = Variable(tensor)
        mask = Variable(mask)
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, _ = sort_batch_by_length(input_tensor,
                                                                                              sequence_lengths)
        packed_sequence = pack_padded_sequence(sorted_inputs,
                                               sorted_sequence_lengths.data.tolist(),
                                               batch_first=True)
        lstm_output, _ = lstm(packed_sequence)
        encoder_output = encoder(input_tensor, mask)
        lstm_tensor, _ = pad_packed_sequence(lstm_output, batch_first=True)
        assert_almost_equal(encoder_output.data.numpy(),
                            lstm_tensor.index_select(0, restoration_indices).data.numpy())

    def setUp(self):
        super(TestEncoderBase, self).setUp()
        self.lstm = LSTM(bidirectional=True,
                         num_layers=3,
                         input_size=3,
                         hidden_size=7,
                         batch_first=True)
        self.encoder_base = _EncoderBase(stateful=True)

        tensor = Variable(torch.rand([5, 7, 3]))
        tensor[1, 6:, :] = 0
        tensor[3, 2:, :] = 0
        self.tensor = tensor
        mask = Variable(torch.ones(5, 7))
        mask[1, 6:] = 0
        mask[2, :] = 0  # <= completely masked
        mask[3, 2:] = 0
        mask[4, :] = 0  # <= completely masked
        self.mask = mask

        self.batch_size = 5
        self.num_valid = 3
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        _, _, restoration_indices, sorting_indices = sort_batch_by_length(
            tensor, sequence_lengths)
        self.sorting_indices = sorting_indices
        self.restoration_indices = restoration_indices

    def test_forward_pulls_out_correct_tensor_with_sequence_lengths(self):
        lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2VecWrapper(lstm)

        tensor = torch.rand([5, 7, 3])
        tensor[1, 6:, :] = 0
        tensor[2, 4:, :] = 0
        tensor[3, 2:, :] = 0
        tensor[4, 1:, :] = 0
        mask = torch.ones(5, 7)
        mask[1, 6:] = 0
        mask[2, 4:] = 0
        mask[3, 2:] = 0
        mask[4, 1:] = 0

        input_tensor = Variable(tensor)
        mask = Variable(mask)
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        packed_sequence = pack_padded_sequence(input_tensor, list(sequence_lengths.data), batch_first=True)
        _, state = lstm(packed_sequence)
        # Transpose output state, extract the last forward and backward states and
        # reshape to be of dimension (batch_size, 2 * hidden_size).
        reshaped_state = state[0].transpose(0, 1)[:, -2:, :].contiguous()
        explicitly_concatenated_state = torch.cat([reshaped_state[:, 0, :].squeeze(1),
                                                   reshaped_state[:, 1, :].squeeze(1)], -1)
        encoder_output = encoder(input_tensor, mask)
        assert_almost_equal(encoder_output.data.numpy(), explicitly_concatenated_state.data.numpy())
Example #8
    def forward(self, tokens, mask=None):  #pylint: disable=arguments-differ
        if mask is not None:
            tokens = tokens * mask.unsqueeze(-1).float()

        # Our input has shape `(batch_size, num_tokens, embedding_dim)`, so we sum out the `num_tokens`
        # dimension.
        summed = tokens.sum(1)

        if self._averaged:
            if mask is not None:
                lengths = get_lengths_from_binary_sequence_mask(mask)
                length_mask = (lengths > 0)

                # Set any length 0 to 1, to avoid dividing by zero.
                lengths = torch.max(lengths, lengths.new_ones(1))
            else:
                lengths = tokens.new_full((1, ), fill_value=tokens.size(1))
                length_mask = None

            summed = summed / lengths.unsqueeze(-1).float()

            if length_mask is not None:
                summed = summed * (length_mask > 0).float().unsqueeze(-1)

        return summed
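The same masked averaging can be reproduced with plain tensor ops; a minimal standalone sketch (not the AllenNLP module itself):

import torch

tokens = torch.rand(2, 4, 3)                # (batch_size, num_tokens, embedding_dim)
mask = torch.tensor([[1., 1., 0., 0.],
                     [1., 1., 1., 1.]])
summed = (tokens * mask.unsqueeze(-1)).sum(1)
lengths = mask.sum(1).clamp(min=1.0)        # guard against fully masked rows
averaged = summed / lengths.unsqueeze(-1)   # matches the module's averaged output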
Example #9
    def test_forward_pulls_out_correct_tensor_with_sequence_lengths(self):
        lstm = LSTM(bidirectional=True,
                    num_layers=3,
                    input_size=3,
                    hidden_size=7,
                    batch_first=True)
        encoder = PytorchSeq2VecWrapper(lstm)

        input_tensor = torch.rand([5, 7, 3])
        input_tensor[1, 6:, :] = 0
        input_tensor[2, 4:, :] = 0
        input_tensor[3, 2:, :] = 0
        input_tensor[4, 1:, :] = 0
        mask = torch.ones(5, 7).bool()
        mask[1, 6:] = False
        mask[2, 4:] = False
        mask[3, 2:] = False
        mask[4, 1:] = False

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        packed_sequence = pack_padded_sequence(input_tensor,
                                               sequence_lengths.tolist(),
                                               batch_first=True)
        _, state = lstm(packed_sequence)
        # Transpose output state, extract the last forward and backward states and
        # reshape to be of dimension (batch_size, 2 * hidden_size).
        reshaped_state = state[0].transpose(0, 1)[:, -2:, :].contiguous()
        explicitly_concatenated_state = torch.cat(
            [reshaped_state[:, 0, :].squeeze(1), reshaped_state[:, 1, :].squeeze(1)], -1)
        encoder_output = encoder(input_tensor, mask)
        assert_almost_equal(encoder_output.data.numpy(),
                            explicitly_concatenated_state.data.numpy())
Example #10
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        predicted_gram_vals = output_dict.pop("gram_vals").cpu().detach().numpy()
        predicted_lemmas = output_dict.pop("lemmas").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)

        assert len(head_tags) == len(heads) == len(lengths) == len(predicted_gram_vals) == len(predicted_lemmas)

        head_tag_labels, head_indices, decoded_gram_vals, decoded_lemmas = [], [], [], []
        for instance_index in range(len(head_tags)):
            instance_heads, instance_tags = heads[instance_index], head_tags[instance_index]
            words, length = output_dict["words"][instance_index], lengths[instance_index]
            gram_vals, lemmas = predicted_gram_vals[instance_index], predicted_lemmas[instance_index]

            words = words[: length.item() - 1]
            gram_vals = gram_vals[: length.item() - 1, :]
            lemmas = lemmas[: length.item() - 1, :]

            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

            inst_gram_vals = []
            for tok_gram_vals in gram_vals:
                dtgv = [self.vocab.get_token_from_index(gram_val, "grammar_value_tags") for gram_val in tok_gram_vals]
                inst_gram_vals.append(dtgv)
            decoded_gram_vals.append(inst_gram_vals)

            inst_lemmas = []
            for word, word_lrules in zip(words, lemmas):
                dtl = [self.lemmatize_helper.lemmatize(word, lrule) for lrule in word_lrules]
                inst_lemmas.append(dtl)
            decoded_lemmas.append(inst_lemmas)

        if self.task_config.task_type == "multitask":
            output_dict["predicted_dependencies"] = head_tag_labels
            output_dict["predicted_heads"] = head_indices
            output_dict["predicted_gram_vals"] = decoded_gram_vals
            output_dict["predicted_lemmas"] = decoded_lemmas
        elif self.task_config.task_type == "single":
            if self.task_config.params["model"] == "morphology":
                output_dict["predicted_gram_vals"] = decoded_gram_vals
            elif self.task_config.params["model"] == "lemmatization":
                output_dict["predicted_lemmas"] = decoded_lemmas
            elif self.task_config.params["model"] == "syntax":
                output_dict["predicted_dependencies"] = head_tag_labels
                output_dict["predicted_heads"] = head_indices
            else:
                assert False, "Unknown model type {}".format(self.task_config.params["model"])
        else:
            assert False, "Unknown task type {}".format(self.task_config.task_type)

        return output_dict
Example #11
 def test_get_sequence_lengths_from_binary_mask(self):
     binary_mask = torch.ByteTensor([[1, 1, 1, 0, 0, 0],
                                     [1, 1, 0, 0, 0, 0],
                                     [1, 1, 1, 1, 1, 1],
                                     [1, 0, 0, 0, 0, 0]])
     lengths = util.get_lengths_from_binary_sequence_mask(binary_mask)
     numpy.testing.assert_array_equal(lengths.numpy(), numpy.array([3, 2, 6, 1]))
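For reference, the utility under test is essentially a row sum over the mask cast to a long tensor; a minimal reimplementation sketch consistent with the assertion above:

import torch

def lengths_from_mask(mask: torch.Tensor) -> torch.Tensor:
    # (batch_size, sequence_length) 0/1 mask -> (batch_size,) lengths
    return mask.long().sum(-1)

mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0, 0]])
assert lengths_from_mask(mask).tolist() == [3, 2]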
Example #12
    def test_forward_pulls_out_correct_tensor_for_unsorted_batches(self):
        lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2SeqWrapper(lstm)
        input_tensor = torch.rand([5, 7, 3])
        input_tensor[0, 3:, :] = 0
        input_tensor[1, 4:, :] = 0
        input_tensor[2, 2:, :] = 0
        input_tensor[3, 6:, :] = 0
        mask = torch.ones(5, 7)
        mask[0, 3:] = 0
        mask[1, 4:] = 0
        mask[2, 2:] = 0
        mask[3, 6:] = 0

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, _ = sort_batch_by_length(input_tensor,
                                                                                              sequence_lengths)
        packed_sequence = pack_padded_sequence(sorted_inputs,
                                               sorted_sequence_lengths.data.tolist(),
                                               batch_first=True)
        lstm_output, _ = lstm(packed_sequence)
        encoder_output = encoder(input_tensor, mask)
        lstm_tensor, _ = pad_packed_sequence(lstm_output, batch_first=True)
        assert_almost_equal(encoder_output.data.numpy(),
                            lstm_tensor.index_select(0, restoration_indices).data.numpy())
Example #14
    def test_forward_pulls_out_correct_tensor_with_unsorted_batches(self):
        lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2VecWrapper(lstm)

        tensor = torch.rand([5, 7, 3])
        tensor[0, 3:, :] = 0
        tensor[1, 4:, :] = 0
        tensor[2, 2:, :] = 0
        tensor[3, 6:, :] = 0
        mask = torch.ones(5, 7)
        mask[0, 3:] = 0
        mask[1, 4:] = 0
        mask[2, 2:] = 0
        mask[3, 6:] = 0

        input_tensor = Variable(tensor)
        mask = Variable(mask)
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices = sort_batch_by_length(input_tensor,
                                                                                           sequence_lengths)
        packed_sequence = pack_padded_sequence(sorted_inputs,
                                               sorted_sequence_lengths.data.tolist(),
                                               batch_first=True)
        _, state = lstm(packed_sequence)
        # Transpose output state, extract the last forward and backward states and
        # reshape to be of dimension (batch_size, 2 * hidden_size).
        sorted_transposed_state = state[0].transpose(0, 1).index_select(0, restoration_indices)
        reshaped_state = sorted_transposed_state[:, -2:, :].contiguous()
        explicitly_concatenated_state = torch.cat([reshaped_state[:, 0, :].squeeze(1),
                                                   reshaped_state[:, 1, :].squeeze(1)], -1)
        encoder_output = encoder(input_tensor, mask)
        assert_almost_equal(encoder_output.data.numpy(), explicitly_concatenated_state.data.numpy())
Example #15
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None):  #pylint: disable=arguments-differ
        if mask is not None:
            tokens = tokens * mask.unsqueeze(-1).float()

        # Our input has shape `(batch_size, num_tokens, embedding_dim)`, so we sum out the `num_tokens`
        # dimension.
        summed = tokens.sum(1)

        if self._averaged:
            if mask is not None:
                lengths = get_lengths_from_binary_sequence_mask(mask)
                length_mask = (lengths > 0)

                # Set any length 0 to 1, to avoid dividing by zero.
                lengths = torch.max(lengths, Variable(lengths.data.new().resize_(1).fill_(1)))
            else:
                lengths = Variable(tokens.data.new().resize_(1).fill_(tokens.size(1)), requires_grad=False)
                length_mask = None

            summed = summed / lengths.unsqueeze(-1).float()

            if length_mask is not None:
                summed = summed * (length_mask > 0).float().unsqueeze(-1)

        return summed

    def decode(self, output_dict: Dict[str, torch.Tensor]):
        """
        Takes the result of forward and creates human readable, non-padded dependency trees.
        :param output_dict:
        :return: output_dict with two new keys, "predicted_labels" and "predicted_heads", which are lists of lists.
        """

        head_tags = output_dict.pop("edge_labels").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        edge_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(
                heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [
                self.vocab.get_token_from_index(label, "head_tags")
                for label in instance_tags
            ]
            edge_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_labels"] = edge_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict
Example #17
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:

        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(
                heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [
                self.vocab.get_token_from_index(label,
                                                self.task_type + "_head_tags")
                for label in instance_tags
            ]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict
Example #18
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None):  #pylint: disable=arguments-differ
        if mask is not None:
            tokens = tokens * mask.unsqueeze(-1).float()

        # Our input has shape `(batch_size, num_tokens, embedding_dim)`, so we sum out the `num_tokens`
        # dimension.
        summed = tokens.sum(1)

        if self._averaged:
            if mask is not None:
                lengths = get_lengths_from_binary_sequence_mask(mask)
                length_mask = (lengths > 0)

                # Set any length 0 to 1, to avoid dividing by zero.
                lengths = torch.max(lengths, Variable(lengths.data.new().resize_(1).fill_(1)))
            else:
                lengths = Variable(tokens.data.new().resize_(1).fill_(tokens.size(1)), requires_grad=False)
                length_mask = None

            summed = summed / lengths.unsqueeze(-1).float()

            if length_mask is not None:
                summed = summed * (length_mask > 0).float().unsqueeze(-1)

        return summed

    def test_forward_pulls_out_correct_tensor_with_unsorted_batches(self):
        lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2VecWrapper(lstm)

        input_tensor = torch.rand([5, 7, 3])
        input_tensor[0, 3:, :] = 0
        input_tensor[1, 4:, :] = 0
        input_tensor[2, 2:, :] = 0
        input_tensor[3, 6:, :] = 0
        mask = torch.ones(5, 7)
        mask[0, 3:] = 0
        mask[1, 4:] = 0
        mask[2, 2:] = 0
        mask[3, 6:] = 0

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, _ = sort_batch_by_length(input_tensor,
                                                                                              sequence_lengths)
        packed_sequence = pack_padded_sequence(sorted_inputs,
                                               sorted_sequence_lengths.tolist(),
                                               batch_first=True)
        _, state = lstm(packed_sequence)
        # Transpose output state, extract the last forward and backward states and
        # reshape to be of dimension (batch_size, 2 * hidden_size).
        sorted_transposed_state = state[0].transpose(0, 1).index_select(0, restoration_indices)
        reshaped_state = sorted_transposed_state[:, -2:, :].contiguous()
        explicitly_concatenated_state = torch.cat([reshaped_state[:, 0, :].squeeze(1),
                                                   reshaped_state[:, 1, :].squeeze(1)], -1)
        encoder_output = encoder(input_tensor, mask)
        assert_almost_equal(encoder_output.data.numpy(), explicitly_concatenated_state.data.numpy())
Example #20
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        ``"tags"`` key to the dictionary with the result.
        """
        all_predictions = output_dict['class_probabilities']
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        for predictions, length in zip(predictions_list, sequence_lengths):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length], transition_matrix)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]
            all_tags.append(tags)
        output_dict['tags'] = all_tags
        return output_dict
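As a quick sanity check of the decoding call: viterbi_decode takes per-step tag scores plus a pairwise transition matrix and returns the best path and its score. A toy sketch, assuming an AllenNLP install:

import torch
from allennlp.nn.util import viterbi_decode

tag_scores = torch.tensor([[2.0, 0.0],
                           [0.0, 2.0],
                           [2.0, 0.0]])  # (sequence_length, num_tags)
transitions = torch.zeros(2, 2)          # no transition preference
path, score = viterbi_decode(tag_scores, transitions)
print(path, score)                       # [0, 1, 0], total score 6.0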
Example #21
    def test_forward_pulls_out_correct_tensor_with_sequence_lengths(self):
        lstm = LSTM(bidirectional=True,
                    num_layers=3,
                    input_size=3,
                    hidden_size=7,
                    batch_first=True)
        encoder = PytorchSeq2SeqWrapper(lstm)
        input_tensor = torch.rand([5, 7, 3])
        input_tensor[1, 6:, :] = 0
        input_tensor[2, 4:, :] = 0
        input_tensor[3, 2:, :] = 0
        input_tensor[4, 1:, :] = 0
        mask = torch.ones(5, 7).bool()
        mask[1, 6:] = False
        mask[2, 4:] = False
        mask[3, 2:] = False
        mask[4, 1:] = False

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        packed_sequence = pack_padded_sequence(input_tensor,
                                               sequence_lengths.data.tolist(),
                                               batch_first=True)
        lstm_output, _ = lstm(packed_sequence)
        encoder_output = encoder(input_tensor, mask)
        lstm_tensor, _ = pad_packed_sequence(lstm_output, batch_first=True)
        assert_almost_equal(encoder_output.data.numpy(),
                            lstm_tensor.data.numpy())
Example #22
    def sort_and_run_forward(
            self,
            module: Callable[[PackedSequence, Optional[RnnState]],
                             Tuple[Union[PackedSequence, torch.Tensor],
                                   RnnState]],
            inputs: torch.Tensor,
            mask: torch.Tensor,
            hidden_states: Optional[RnnState] = None,
            reset_hidden_state=False):
        # First count how many sequences are empty.
        batch_size = mask.size(0)
        num_valid = torch.sum(mask[:, 0]).int().item()

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices =\
            sort_batch_by_length(inputs, sequence_lengths)

        # Now create a PackedSequence with only the non-empty, sorted sequences.
        packed_sequence_input = pack_padded_sequence(
            sorted_inputs[:num_valid, :, :],
            sorted_sequence_lengths[:num_valid].data.tolist(),
            batch_first=True)
        # Prepare the initial states.
        initial_states, hidden_states = self._get_initial_states(
            batch_size, num_valid, sorting_indices, hidden_states)

        if reset_hidden_state:
            initial_states = None

        # Actually call the module on the sorted PackedSequence.
        module_output, final_states = module(packed_sequence_input,
                                             initial_states)

        return module_output, final_states, restoration_indices, hidden_states
Example #23
    def forward(self, hidden_states, attention_mask, pad_start=None):
        all_hidden_states = [hidden_states]
        all_attentions = (None,)
        debug_info = {}

        batch_size, seq_length, dim = hidden_states.size()

        num_layers = min(len(self.layer), seq_length - 1)
        root_answers = torch.zeros(batch_size, dim).float().to(hidden_states.device)

        sentence_lengths = util.get_lengths_from_binary_sequence_mask(attention_mask)

        for i, layer_module in enumerate(self.layer[:num_layers]):
            layer_outputs = layer_module(all_hidden_states, i)
            hidden_states, attentions, layer_debug_info = layer_outputs

            # Add last layer
            all_hidden_states += [hidden_states]
            all_attentions = all_attentions + (attentions,)

            # only update the root hidden state while the layer index is still within the
            # sentence length (sentences shorter than num_layers stop updating early),
            # so we keep a mask of the relevant batch indices
            update_root_indices_mask = (sentence_lengths - 2 >= i).float()
            update_root_indices_mask = update_root_indices_mask.unsqueeze(1).to(hidden_states.device)

            answer, answer_debug_info = self.answer_nn(hidden_states[:, 0])
            layer_debug_info.update(answer_debug_info)

            root_answers = root_answers * (1 - update_root_indices_mask) + update_root_indices_mask * answer

            for key, val in layer_debug_info.items():
                debug_info[f'{key}_{i+1}'] = val

        outputs = (root_answers, all_hidden_states, all_attentions, debug_info)
        return outputs  # last-layer hidden state, all hidden states, all attentions

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        arc_tag_probs = output_dict["arc_tag_probs"].cpu().detach().numpy()
        arc_probs = output_dict["arc_probs"].cpu().detach().numpy()
        mask = output_dict["mask"]
        lengths = get_lengths_from_binary_sequence_mask(mask)
        arcs = []
        arc_tags = []
        for instance_arc_probs, instance_arc_tag_probs, length in zip(
                arc_probs, arc_tag_probs, lengths):

            arc_matrix = instance_arc_probs > self.edge_prediction_threshold
            edges = []
            edge_tags = []
            for i in range(length):
                for j in range(length):
                    if arc_matrix[i, j] == 1:
                        edges.append((i, j))
                        tag = instance_arc_tag_probs[i, j].argmax(-1)
                        edge_tags.append(
                            self.vocab.get_token_from_index(tag, "labels"))
            arcs.append(edges)
            arc_tags.append(edge_tags)

        output_dict["arcs"] = arcs
        output_dict["arc_tags"] = arc_tags
        return output_dict
Example #25
    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.Tensor,
            hidden_state: torch.Tensor = None) -> torch.Tensor:

        if mask is None:
            return self._module(inputs, hidden_state)[0]

        # In some circumstances you may have sequences of zero length.
        # ``pack_padded_sequence`` requires all sequence lengths to be > 0, so here we
        # adjust the ``mask`` so that every sequence has length at least 1. Then after
        # running the RNN we zero out the corresponding rows in the result.

        # First count how many sequences are empty.
        batch_size, total_sequence_length = mask.size()
        num_valid = torch.sum(mask[:, 0]).int().data[0]

        # Force every sequence to be length at least one. Need to `.clone()` the mask
        # to avoid a RuntimeError from shared storage.
        if num_valid < batch_size:
            mask = mask.clone()
            mask[:, 0] = 1

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices = sort_batch_by_length(
            inputs, sequence_lengths)
        packed_sequence_input = pack_padded_sequence(
            sorted_inputs,
            sorted_sequence_lengths.data.tolist(),
            batch_first=True)

        # Actually call the module on the sorted PackedSequence.
        packed_sequence_output, _ = self._module(packed_sequence_input,
                                                 hidden_state)
        unpacked_sequence_tensor, _ = pad_packed_sequence(
            packed_sequence_output, batch_first=True)

        # We sorted by length, so if there are invalid rows that need to be zeroed out
        # they will be at the end.
        if num_valid < batch_size:
            unpacked_sequence_tensor[num_valid:, :, :] = 0.

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2SeqEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - unpacked_sequence_tensor.size(
            1)
        if sequence_length_difference > 0:
            zeros = unpacked_sequence_tensor.data.new(
                batch_size, sequence_length_difference,
                unpacked_sequence_tensor.size(-1)).fill_(0)
            zeros = torch.autograd.Variable(zeros)
            unpacked_sequence_tensor = torch.cat(
                [unpacked_sequence_tensor, zeros], 1)

        # Restore the original indices and return the sequence.
        return unpacked_sequence_tensor.index_select(0, restoration_indices)
Example #26
    def forward(
        self,
        encoder_outs: Dict[str, Any],
        target_tokens: Optional[TextFieldTensors] = None,
    ) -> Dict[str, torch.Tensor]:
        """Implementation of the forward pass.

        Make a forward pass of the decoder, given the latent vector and the target sentences as references.
        Notice the explanation in simple_seq2seq.py for creating the relevant targets and mask.

        Notice the indexing here, contrast with indexing in _get_reconstruction_loss
        """
        # TODO: Right now we are always using teacher forcing if target tokens are available.
        # Change this to use scheduled sampling
        z = encoder_outs['z']
        batch_size = z.size(0)
        if target_tokens is not None:
            target_mask = get_text_field_mask(target_tokens)
            relevant_targets = {
                'tokens': {
                    'tokens': target_tokens['tokens']['tokens'][:, :-1]
                }
            }
            relevant_mask = target_mask[:, :-1]

            # num_tokens should technically be target_mask[:, 1:] but the lengths turn out the same
            # So we make use of this precomputed slice
            num_tokens = (
                get_lengths_from_binary_sequence_mask(relevant_mask) -
                1).sum()
            embeddings, state = self._prepare_decoder(z, relevant_targets)
            logits = self._run_decoder(embeddings, relevant_mask, state, z)

            class_probabilities = torch_f.softmax(logits, 2)
            _, best_predictions = torch.max(class_probabilities, 2)

            # Notice that we are not using relevant targets here
            loss = self._get_reconstruction_loss(
                logits, target_tokens['tokens']['tokens'], target_mask)
            output_dict = {
                'logits': logits,
                'predictions': best_predictions,
                'loss': loss
            }
            if not self.training:
                self._bleu(output_dict['predictions'],
                           target_tokens['tokens']['tokens'])
        else:
            output_dict = self.generate(z, target_tokens)
            # This computation of num_tokens is technically wrong, but we don't really care
            # about computing this kind of PPL for this case so we leave it here
            # TODO verify -2 because both start and end ?
            num_tokens = (output_dict['predictions'].size(1) - 2) * batch_size

        nll = output_dict['loss'] + encoder_outs['kl']
        self._nll(nll)
        self._ppl(nll * batch_size, num_tokens)
        return output_dict
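The final lines accumulate metric state; perplexity here is the exponential of the average per-token NLL. A minimal sketch of that relationship (a hypothetical helper, not the metric classes used above):

import torch

def perplexity(total_nll: torch.Tensor, num_tokens: int) -> torch.Tensor:
    # exp of the average per-token negative log-likelihood
    return torch.exp(total_nll / num_tokens)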
Example #27
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, sequence_length, num_classes).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class labels of shape (batch_size, sequence_length). It must be the same
            shape as the ``predictions`` tensor without the ``num_classes`` dimension.
        mask: ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        """
        if mask is None:
            mask = torch.ones(gold_labels.size())
        # If you actually passed in Variables here instead of Tensors, this will be a huge memory
        # leak, because it will prevent garbage collection for the computation graph.  We'll ensure
        # that we're using tensors here first.
        if isinstance(predictions, Variable):
            predictions = predictions.data.cpu()
        if isinstance(gold_labels, Variable):
            gold_labels = gold_labels.data.cpu()
        if isinstance(mask, Variable):
            mask = mask.data.cpu()

        num_classes = predictions.size(-1)
        if (gold_labels >= num_classes).any():
            raise ConfigurationError(
                "A gold label passed to SpanBasedF1Measure contains an "
                "id >= {}, the number of classes.".format(num_classes))

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        argmax_predictions = predictions.max(-1)[1].float()

        # Iterate over timesteps in batch.
        batch_size = gold_labels.size(0)
        for i in range(batch_size):
            sequence_prediction = argmax_predictions[i, :]
            sequence_gold_label = gold_labels[i, :]
            length = sequence_lengths[i]
            prediction_spans = self._extract_spans(
                sequence_prediction[:length].tolist())
            gold_spans = self._extract_spans(
                sequence_gold_label[:length].tolist())

            for span in prediction_spans:
                if span in gold_spans:
                    self._true_positives[span[1]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[1]] += 1
            # These spans weren't predicted.
            for span in gold_spans:
                self._false_negatives[span[1]] += 1
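_extract_spans is not shown above; a minimal sketch of BIO span extraction producing the ((start, end), label) tuples the counting loop indexes with span[1] (the real class's handling of stray I- tags may differ):

from typing import List, Set, Tuple

def extract_bio_spans(tags: List[str]) -> Set[Tuple[Tuple[int, int], str]]:
    spans, start, label = set(), None, None
    for i, tag in enumerate(tags + ["O"]):  # the "O" sentinel flushes the final span
        if tag == "O" or tag.startswith("B-") or (tag.startswith("I-") and tag[2:] != label):
            if start is not None:
                spans.add(((start, i - 1), label))
                start, label = None, None
        if tag.startswith("B-"):
            start, label = i, tag[2:]
    return spans

assert extract_bio_spans(["B-ARG0", "I-ARG0", "O", "B-V"]) == {((0, 1), "ARG0"), ((3, 3), "V")}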
Example #28
 def test_decode_runs_correctly(self):
     training_tensors = self.dataset.as_tensor_dict()
     output_dict = self.model(**training_tensors)
     decode_output_dict = self.model.decode(output_dict)
     lengths = get_lengths_from_binary_sequence_mask(decode_output_dict[u"mask"]).data.tolist()
     # Hard to check anything concrete which we haven't checked in the above
     # test, so we'll just check that the tags are equal to the lengths
     # of the individual instances, rather than the max length.
     for prediction, length in izip(decode_output_dict[u"tags"], lengths):
         assert len(prediction) == length
Example #29
 def test_decode_runs_correctly(self):
     training_tensors = self.dataset.as_tensor_dict()
     output_dict = self.model(**training_tensors)
     decode_output_dict = self.model.decode(output_dict)
     lengths = get_lengths_from_binary_sequence_mask(decode_output_dict["mask"]).data.tolist()
     # Hard to check anything concrete which we haven't checked in the above
     # test, so we'll just check that the tags are equal to the lengths
     # of the individual instances, rather than the max length.
     for prediction, length in zip(decode_output_dict["tags"], lengths):
         assert len(prediction) == length
Example #30
    def forward(
            self,
            tokens: Dict[str, torch.Tensor],
            passage: List[Passage],
            dataset_label: List[str],
            spans: torch.Tensor,
            lang: torch.Tensor,
            id: List[str] = None,
            gold_ucca_tree: List[Passage] = None,
            gold_primary_tree: List[InternalParseNode] = None,
            span_labels: torch.Tensor = None,
            remote_heads: torch.Tensor = None,
            remote_deps: torch.Tensor = None,
            remote_labels: torch.Tensor = None,
            remote_nodes_spans: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        embedded_text_input = self.token_embedder(tokens, lang=lang)
        text_mask = get_text_field_mask(tokens)
        sentence_lengths = get_lengths_from_binary_sequence_mask(text_mask)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()

        encoded_text = self.encoder(embedded_text_input, text_mask)
        # span_representations.shape: torch.Size([1, 276, 20])
        span_representations = self.span_extractor(encoded_text, spans,
                                                   text_mask, span_mask)

        predicted_tree = self.span_decoder.predict(span_representations,
                                                   sentence_lengths)
        predict_passages = list(map(to_UCCA, passage, predicted_tree))
        predict_passages = self.remote_parser.restore_remote(
            predict_passages, span_representations, sentence_lengths)
        output = {"prediction": predict_passages}

        if gold_ucca_tree is not None:
            tree_loss = self.span_decoder.get_loss(span_representations,
                                                   sentence_lengths,
                                                   gold_primary_tree)
            # Looking at the span start index is enough to know if
            # this is padding or not. Shape: (batch_size, num_spans)
            remote_nodes_spans_mask = (remote_nodes_spans[:, :, 0] >=
                                       0).squeeze(-1).long()
            remote_labels_mask = (remote_labels[:, :] >= 0).squeeze(-1).long()
            remote_loss = self.remote_parser.get_loss(
                span_representations, sentence_lengths, remote_nodes_spans,
                remote_nodes_spans_mask, remote_heads, remote_deps,
                remote_labels, remote_labels_mask)

            output["loss"] = tree_loss + remote_loss
            for i in range(len(predicted_tree)):
                self.evaluator(dataset_label[i], predict_passages[i],
                               gold_ucca_tree[i])

        return output
Example #31
 def test_get_sequence_lengths_converts_to_long_tensor_and_avoids_variable_overflow(self):
     # Tests the following weird behaviour in Pytorch 0.1.12
     # doesn't happen for our sequence masks:
     #
     # mask = torch.ones([260]).byte()
     # mask.sum() # equals 260.
     # var_mask = t.a.V(mask)
     # var_mask.sum() # equals 4, due to 8 bit precision - the sum overflows.
     binary_mask = torch.ones(2, 260).byte()
     lengths = util.get_lengths_from_binary_sequence_mask(binary_mask)
     numpy.testing.assert_array_equal(lengths.data.numpy(), numpy.array([260, 260]))
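The fix the utility applies is to cast the mask to a 64-bit integer type before summing; a one-line sketch of the safe pattern:

import torch

binary_mask = torch.ones(2, 260, dtype=torch.uint8)
lengths = binary_mask.long().sum(-1)  # cast first, so the 8-bit sum cannot overflow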
Example #33
 def init_enc_state(
         self,
         source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     source_mask = util.get_text_field_mask(source_tokens)
     source_lengths = get_lengths_from_binary_sequence_mask(source_mask)
     state = {
         'source_mask': source_mask,  # (B, L)
          'source_lengths': source_lengths,  # (B,)
         'source_tokens': source_tokens['tokens'],
     }
     return state
Example #34
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:

        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        predicted_gram_vals = output_dict.pop("gram_vals").cpu().detach().numpy()
        predicted_lemmas = output_dict.pop("lemmas").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)

        assert len(head_tags) == len(heads) == len(lengths) == len(
            predicted_gram_vals) == len(predicted_lemmas)

        head_tag_labels, head_indices, decoded_gram_vals, decoded_lemmas = [], [], [], []
        for instance_index in range(len(head_tags)):
            instance_heads, instance_tags = heads[instance_index], head_tags[instance_index]
            words, length = output_dict["words"][instance_index], lengths[instance_index]
            gram_vals, lemmas = predicted_gram_vals[instance_index], predicted_lemmas[instance_index]

            words = words[:length.item() - 1]
            gram_vals = gram_vals[:length.item() - 1]
            lemmas = lemmas[:length.item() - 1]

            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [
                self.vocab.get_token_from_index(label, "head_tags")
                for label in instance_tags
            ]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

            decoded_gram_vals.append([
                self.vocab.get_token_from_index(gram_val, "grammar_value_tags")
                for gram_val in gram_vals
            ])

            decoded_lemmas.append([
                self.lemmatize_helper.lemmatize(word, lemmatize_rule_index)
                for word, lemmatize_rule_index in zip(words, lemmas)
            ])

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        output_dict["predicted_gram_vals"] = decoded_gram_vals
        output_dict["predicted_lemmas"] = decoded_lemmas

        return output_dict
Example #35
    def sort_and_run_forward(
            self,
            module: Callable[[PackedSequence, Optional[RnnState]],
                             Tuple[Union[PackedSequence, torch.Tensor],
                                   RnnState], ],
            inputs: torch.Tensor,
            mask: torch.Tensor,
            hidden_state: Optional[RnnState] = None,
            prevs=None,
            rev_prevs=None):
        # In some circumstances you may have sequences of zero length. ``pack_padded_sequence``
        # requires all sequence lengths to be > 0, so remove sequences of zero length before
        # calling self._module, then fill with zeros.

        # First count how many sequences are empty.
        batch_size = mask.size(0)
        num_valid = torch.sum(mask[:, 0]).int().item()

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices = sort_batch_by_length(
            inputs, sequence_lengths)

        prevs = [prevs[i] for i in sorting_indices][:num_valid]
        rev_prevs = [rev_prevs[i] for i in sorting_indices][:num_valid]

        # Now create a PackedSequence with only the non-empty, sorted sequences.
        packed_sequence_input = pack_padded_sequence(
            sorted_inputs[:num_valid, :, :],
            sorted_sequence_lengths[:num_valid].data.tolist(),
            batch_first=True,
        )
        # Prepare the initial states.
        if not self.stateful:
            if hidden_state is None:
                initial_states: Any = hidden_state
            elif isinstance(hidden_state, tuple):
                initial_states = [
                    state.index_select(
                        1, sorting_indices)[:, :num_valid, :].contiguous()
                    for state in hidden_state
                ]
            else:
                initial_states = hidden_state.index_select(
                    1, sorting_indices)[:, :num_valid, :].contiguous()
        else:
            initial_states = self._get_initial_states(batch_size, num_valid,
                                                      sorting_indices)

        # Actually call the module on the sorted PackedSequence.
        module_output, final_states = module(packed_sequence_input,
                                             initial_states, prevs, rev_prevs)

        return module_output, final_states, restoration_indices
Example #36
    def only_cle(self, output_dict: Dict[str, torch.Tensor]):
        """
        We take the result of forward and perform the following steps (for each sentence in batch):
        - remove padding
        :param output_dict: result of forward
        :return: output_dict with the following keys added:
            - lexlabels: nested list: contains for each sentence, for each word the most likely lexical label (w/o artificial root)
            - supertags: nested list: contains for each sentence, for each word the most likely supertag (w/o artificial root)
        """
        full_label_logits = output_dict.pop("full_label_logits").cpu().detach(
        ).numpy()  #shape (batch size, seq len, seq len, num edge labels)
        edge_existence_scores = output_dict.pop(
            "edge_existence_scores").cpu().detach().numpy(
            )  #shape (batch size, seq len, seq len)
        heads = output_dict.pop("heads")
        heads_cpu = heads.cpu().detach().numpy()
        mask = output_dict.pop("mask")
        edge_label_logits = output_dict.pop("label_logits").cpu().detach(
        ).numpy()  # shape (batch_size, seq_len, num edge labels)

        output_dict.pop("encoded_text_parsing")
        output_dict.pop("encoded_text_tagging")  #don't need that

        lengths = get_lengths_from_binary_sequence_mask(mask)

        #here we collect things, in the end we will have one entry for each sentence:
        all_edge_label_logits = []
        head_indices = []
        all_full_label_logits = []
        all_edge_existence_scores = []

        for i, length in enumerate(lengths):
            instance_heads_cpu = list(heads_cpu[i, 1:length])
            #apply changes to instance_heads tensor:
            instance_heads = heads[i, :]
            for j, x in enumerate(instance_heads_cpu):
                instance_heads[j + 1] = torch.tensor(
                    x
                )  #+1 because we removed the first position from instance_heads_cpu

            all_edge_label_logits.append(edge_label_logits[i, 1:length, :])

            all_full_label_logits.append(
                full_label_logits[i, :length, :length, :])
            all_edge_existence_scores.append(
                edge_existence_scores[i, :length, :length])
            head_indices.append(instance_heads_cpu)

        output_dict["label_logits"] = all_edge_label_logits
        output_dict["predicted_heads"] = head_indices
        output_dict["full_label_logits"] = all_full_label_logits
        output_dict["edge_existence_scores"] = all_edge_existence_scores
        return output_dict
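The loop in only_cle is an instance of a recurring pattern: derive lengths from the mask, then crop each batch element back to its real size (here also dropping the artificial root at position 0). A standalone sketch with made-up shapes:

import torch

batch_scores = torch.rand(2, 5, 3)  # (batch_size, seq_len, num_labels)
mask = torch.tensor([[1, 1, 1, 1, 0],
                     [1, 1, 1, 0, 0]])
lengths = mask.sum(-1)

# One entry per sentence, root (position 0) and padding removed.
per_sentence = [batch_scores[i, 1:length] for i, length in enumerate(lengths)]
assert per_sentence[0].shape == (3, 3) and per_sentence[1].shape == (2, 3)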
Example #37
def write_predictions_viterbi_decoded(serialization_dir, split, epoch, predicted_tags,
                                      vocab: Vocabulary,
                                      tokens: Dict[str, torch.LongTensor],
                                      verb_indicator: torch.LongTensor,
                                      tags: torch.LongTensor = None,
                                      pos_tags: torch.LongTensor = None,
                                      spans: torch.LongTensor = None,
                                      span_labels: torch.LongTensor = None,
                                      metadata: Any = None):

    prediction_file_path = os.path.join(serialization_dir, "predictions", "predictions-" + split + "-" + str(epoch) + ".txt")
    gold_file_path = os.path.join(serialization_dir, "predictions", "gold-" + split + "-" + str(epoch) + ".txt")

    if not os.path.exists(os.path.dirname(prediction_file_path)):
        try:
            os.makedirs(os.path.dirname(prediction_file_path))
        except OSError as exc:  # Guard against race condition
            if exc.errno != errno.EEXIST:
                raise

    # logger.info("Writing gold srl tags (in conll file format) to %s", gold_file_path)
    # logger.info("Writing predicted srl tags (in conll file format) to %s", prediction_file_path)

    prediction_file = open(prediction_file_path, "a+")
    gold_file = open(gold_file_path, "a+")

    sentences = tokens["tokens"]
    mask = get_text_field_mask(tokens)
    sentence_lengths = get_lengths_from_binary_sequence_mask(mask).data.tolist()

    for sentence, _gold_tags, _verb_indicator, _length, _predicted_tags in zip(sentences.data.cpu(), tags.data.cpu(),
                                                            verb_indicator.data.cpu(), sentence_lengths, predicted_tags.data.cpu()):
        tokens = [str(vocab.get_token_from_index(x, namespace="tokens"))
                  for x in sentence[:_length]]
        gold_labels = [vocab.get_token_from_index(x, namespace="labels")
                       for x in _gold_tags[:_length]]
        _verb_indicator = list(_verb_indicator[:_length])

        prediction = [vocab.get_token_from_index(x, namespace="labels")
                      for x in _predicted_tags[:_length]]

        try:
            verb_index = _verb_indicator.index(1)
        except ValueError:
            verb_index = None

        # Defined in semantic_role_labeler model implementation
        write_to_conll_eval_file(prediction_file=prediction_file, gold_file=gold_file,
                                 verb_index=verb_index, sentence=tokens, prediction=prediction,
                                 gold_labels=gold_labels)
    prediction_file.close()
    gold_file.close()
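As an aside, the errno-based race guard above predates Python 3.2; on any modern interpreter the same behaviour is a one-liner (a drop-in simplification, not part of the original snippet):

import os

prediction_file_path = os.path.join("serialization_dir", "predictions", "predictions-dev-0.txt")
os.makedirs(os.path.dirname(prediction_file_path), exist_ok=True)  # no errno check needed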
Example #38
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

        head_tags = output_dict["head_tags"].cpu().detach().numpy()
        heads = output_dict["heads"].cpu().detach().numpy()
        lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"])
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [self.vocab.get_token_from_index(label, "head_tags")
                      for label in instance_tags]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict
Example #39
    def test_forward_pulls_out_correct_tensor_with_sequence_lengths(self):
        lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2SeqWrapper(lstm)
        input_tensor = torch.rand([5, 7, 3])
        input_tensor[1, 6:, :] = 0
        input_tensor[2, 4:, :] = 0
        input_tensor[3, 2:, :] = 0
        input_tensor[4, 1:, :] = 0
        mask = torch.ones(5, 7)
        mask[1, 6:] = 0
        mask[2, 4:] = 0
        mask[3, 2:] = 0
        mask[4, 1:] = 0

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        packed_sequence = pack_padded_sequence(input_tensor, sequence_lengths.data.tolist(), batch_first=True)
        lstm_output, _ = lstm(packed_sequence)
        encoder_output = encoder(input_tensor, mask)
        lstm_tensor, _ = pad_packed_sequence(lstm_output, batch_first=True)
        assert_almost_equal(encoder_output.data.numpy(), lstm_tensor.data.numpy())
Example #40
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        ``"tags"`` key to the dictionary with the result.
        """
        all_predictions = output_dict['class_probabilities']
        sequence_lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [all_predictions[i].detach().cpu() for i in range(all_predictions.size(0))]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        for predictions, length in zip(predictions_list, sequence_lengths):
            max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix)
            tags = [self.vocab.get_token_from_index(x, namespace="labels")
                    for x in max_likelihood_sequence]
            all_tags.append(tags)
        output_dict['tags'] = all_tags
        return output_dict
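To see why the BIO constraint matters, compare greedy argmax with a constrained Viterbi pass over the same scores. The sketch below is a minimal pure-PyTorch Viterbi (not AllenNLP's viterbi_decode) over a hypothetical three-tag vocabulary; the O -> I-ARG transition is forbidden via a -inf pairwise potential, so the decoder repairs a sequence that greedy decoding would get wrong:

import torch

tags = ["O", "B-ARG", "I-ARG"]
transitions = torch.zeros(3, 3)
transitions[0, 2] = float("-inf")  # O -> I-ARG is not a valid BIO transition

def viterbi(emissions, transitions):
    # emissions: (sequence_length, num_tags) per-step tag scores.
    scores = emissions[0]
    backpointers = []
    for t in range(1, emissions.size(0)):
        # total[i, j]: best score ending in tag i, then transitioning to tag j.
        total = scores.unsqueeze(1) + transitions + emissions[t].unsqueeze(0)
        scores, best_prev = total.max(0)
        backpointers.append(best_prev)
    path = [scores.argmax().item()]
    for bp in reversed(backpointers):
        path.append(bp[path[-1]].item())
    return list(reversed(path))

emissions = torch.tensor([[0.9, 0.5, 0.1],   # greedy picks O here ...
                          [0.1, 0.2, 0.9]])  # ... and I-ARG here: invalid BIO
print([tags[i] for i in viterbi(emissions, transitions)])  # ['B-ARG', 'I-ARG']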
Example #41
    def setUp(self):
        super(TestEncoderBase, self).setUp()
        self.lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        self.encoder_base = _EncoderBase(stateful=True)

        tensor = Variable(torch.rand([5, 7, 3]))
        tensor[1, 6:, :] = 0
        tensor[3, 2:, :] = 0
        self.tensor = tensor
        mask = Variable(torch.ones(5, 7))
        mask[1, 6:] = 0
        mask[2, :] = 0  # <= completely masked
        mask[3, 2:] = 0
        mask[4, :] = 0  # <= completely masked
        self.mask = mask

        self.batch_size = 5
        self.num_valid = 3
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        _, _, restoration_indices, sorting_indices = sort_batch_by_length(tensor, sequence_lengths)
        self.sorting_indices = sorting_indices
        self.restoration_indices = restoration_indices
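The sorting_indices / restoration_indices pair stored by this fixture can be reproduced with two argsorts, since the argsort of a permutation is its inverse; a self-contained sketch (not the sort_batch_by_length helper itself):

import torch

lengths = torch.tensor([7, 6, 0, 2, 0])  # same lengths as the fixture's masks
sorting_indices = lengths.argsort(descending=True)
restoration_indices = sorting_indices.argsort()  # inverse permutation

sorted_lengths = lengths[sorting_indices]
assert torch.equal(sorted_lengths[restoration_indices], lengths)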
Example #42
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        arc_tag_probs = output_dict["arc_tag_probs"].cpu().detach().numpy()
        arc_probs = output_dict["arc_probs"].cpu().detach().numpy()
        mask = output_dict["mask"]
        lengths = get_lengths_from_binary_sequence_mask(mask)
        arcs = []
        arc_tags = []
        for instance_arc_probs, instance_arc_tag_probs, length in zip(arc_probs, arc_tag_probs, lengths):

            arc_matrix = instance_arc_probs > self.edge_prediction_threshold
            edges = []
            edge_tags = []
            for i in range(length):
                for j in range(length):
                    if arc_matrix[i, j] == 1:
                        edges.append((i, j))
                        tag = instance_arc_tag_probs[i, j].argmax(-1)
                        edge_tags.append(self.vocab.get_token_from_index(tag, "labels"))
            arcs.append(edges)
            arc_tags.append(edge_tags)

        output_dict["arcs"] = arcs
        output_dict["arc_tags"] = arc_tags
        return output_dict
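The nested i/j loop above can be expressed in one vectorized call; a small sketch of the same thresholding (keeping the snippet's convention that an edge exists wherever the arc probability exceeds the threshold):

import torch

arc_probs = torch.rand(5, 5)
# nonzero() on the boolean matrix yields one (i, j) row per surviving edge.
edges = [tuple(e) for e in (arc_probs > 0.5).nonzero().tolist()]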
Example #43
    def forward(self,
                context_1: torch.Tensor,
                mask_1: torch.Tensor,
                context_2: torch.Tensor,
                mask_2: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # pylint: disable=arguments-differ
        """
        Given the forward (or backward) representations of sentence1 and sentence2, apply four bilateral
        matching functions between them in one direction.

        Parameters
        ----------
        context_1 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len1, hidden_dim) representing the encoding of the first sentence.
        mask_1 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len1), indicating which
            positions in the first sentence are padding (0) and which are not (1).
        context_2 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len2, hidden_dim) representing the encoding of the second sentence.
        mask_2 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len2), indicating which
            positions in the second sentence are padding (0) and which are not (1).

        Returns
        -------
        A tuple of matching vectors for the two sentences, each of which is a list of
        matching vectors of shape (batch, seq_len, num_perspectives or 1).
        """
        assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
        assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim

        # (batch,)
        len_1 = get_lengths_from_binary_sequence_mask(mask_1)
        len_2 = get_lengths_from_binary_sequence_mask(mask_2)

        # (batch, seq_len*)
        mask_1, mask_2 = mask_1.float(), mask_2.float()

        # explicitly set masked weights to zero
        # (batch_size, seq_len*, hidden_dim)
        context_1 = context_1 * mask_1.unsqueeze(-1)
        context_2 = context_2 * mask_2.unsqueeze(-1)

        # array to keep the matching vectors for the two sentences
        matching_vector_1: List[torch.Tensor] = []
        matching_vector_2: List[torch.Tensor] = []

        # Step 0. unweighted cosine
        # First calculate the cosine similarities between each forward
        # (or backward) contextual embedding and every forward (or backward)
        # contextual embedding of the other sentence.

        # (batch, seq_len1, seq_len2)
        cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2), context_2.unsqueeze(-3), dim=3)

        # (batch, seq_len*, 1)
        cosine_max_1 = masked_max(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
        cosine_mean_1 = masked_mean(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
        cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)
        cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)

        matching_vector_1.extend([cosine_max_1, cosine_mean_1])
        matching_vector_2.extend([cosine_max_2, cosine_mean_2])

        # Step 1. Full-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with the last time step of the forward (or backward)
        # contextual embedding of the other sentence
        if self.with_full_match:

            # (batch, 1, hidden_dim)
            if self.is_forward:
                # (batch, 1, hidden_dim)
                last_position_1 = (len_1 - 1).clamp(min=0)
                last_position_1 = last_position_1.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)
                last_position_2 = (len_2 - 1).clamp(min=0)
                last_position_2 = last_position_2.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)

                context_1_last = context_1.gather(1, last_position_1)
                context_2_last = context_2.gather(1, last_position_2)
            else:
                context_1_last = context_1[:, 0:1, :]
                context_2_last = context_2[:, 0:1, :]

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_full = multi_perspective_match(context_1,
                                                             context_2_last,
                                                             self.full_match_weights)
            matching_vector_2_full = multi_perspective_match(context_2,
                                                             context_1_last,
                                                             self.full_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_full)
            matching_vector_2.extend(matching_vector_2_full)

        # Step 2. Maxpooling-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with every time step of the forward (or backward)
        # contextual embedding of the other sentence, and only the max value of each
        # dimension is retained.
        if self.with_maxpool_match:
            # (batch, seq_len1, seq_len2, num_perspectives)
            matching_vector_max = multi_perspective_match_pairwise(context_1,
                                                                   context_2,
                                                                   self.maxpool_match_weights)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_max = masked_max(matching_vector_max,
                                               mask_2.unsqueeze(-2).unsqueeze(-1),
                                               dim=2)
            matching_vector_1_mean = masked_mean(matching_vector_max,
                                                 mask_2.unsqueeze(-2).unsqueeze(-1),
                                                 dim=2)
            matching_vector_2_max = masked_max(matching_vector_max.permute(0, 2, 1, 3),
                                               mask_1.unsqueeze(-2).unsqueeze(-1),
                                               dim=2)
            matching_vector_2_mean = masked_mean(matching_vector_max.permute(0, 2, 1, 3),
                                                 mask_1.unsqueeze(-2).unsqueeze(-1),
                                                 dim=2)

            matching_vector_1.extend([matching_vector_1_max, matching_vector_1_mean])
            matching_vector_2.extend([matching_vector_2_max, matching_vector_2_mean])


        # Step 3. Attentive-Matching
        # Each forward (or backward) similarity is taken as the weight
        # of the forward (or backward) contextual embedding, and calculate an
        # attentive vector for the sentence by weighted summing all its
        # contextual embeddings.
        # Finally match each forward (or backward) contextual embedding
        # with its corresponding attentive vector.

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_2 = context_2.unsqueeze(-3) * cosine_sim.unsqueeze(-1)

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_1 = context_1.unsqueeze(-2) * cosine_sim.unsqueeze(-1)

        if self.with_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_mean_2 = masked_softmax(att_2.sum(dim=2), mask_1.unsqueeze(-1))
            att_mean_1 = masked_softmax(att_1.sum(dim=1), mask_2.unsqueeze(-1))

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_mean = multi_perspective_match(context_1,
                                                                 att_mean_2,
                                                                 self.attentive_match_weights)
            matching_vector_2_att_mean = multi_perspective_match(context_2,
                                                                 att_mean_1,
                                                                 self.attentive_match_weights_reversed)
            matching_vector_1.extend(matching_vector_1_att_mean)
            matching_vector_2.extend(matching_vector_2_att_mean)

        # Step 4. Max-Attentive-Matching
        # Pick the contextual embeddings with the highest cosine similarity as the attentive
        # vector, and match each forward (or backward) contextual embedding with its
        # corresponding attentive vector.
        if self.with_max_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_max_2 = masked_max(att_2, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            att_max_1 = masked_max(att_1.permute(0, 2, 1, 3), mask_1.unsqueeze(-2).unsqueeze(-1), dim=2)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_max = multi_perspective_match(context_1,
                                                                att_max_2,
                                                                self.max_attentive_match_weights)
            matching_vector_2_att_max = multi_perspective_match(context_2,
                                                                att_max_1,
                                                                self.max_attentive_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_att_max)
            matching_vector_2.extend(matching_vector_2_att_max)

        return matching_vector_1, matching_vector_2
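Step 0 above packs all pairwise cosine similarities into a single broadcasted call; a stripped-down sketch of that trick with made-up sizes:

import torch
import torch.nn.functional as F

context_1 = torch.rand(2, 4, 8)  # (batch, seq_len1, hidden)
context_2 = torch.rand(2, 5, 8)  # (batch, seq_len2, hidden)

# (batch, seq_len1, 1, hidden) vs. (batch, 1, seq_len2, hidden) broadcast
# along dim=3 into a (batch, seq_len1, seq_len2) similarity matrix.
cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2),
                                 context_2.unsqueeze(-3), dim=3)
assert cosine_sim.shape == (2, 4, 5)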
Example #44
    def sort_and_run_forward(self,
                             module: Callable[[PackedSequence, Optional[RnnState]],
                                              Tuple[Union[PackedSequence, torch.Tensor], RnnState]],
                             inputs: torch.Tensor,
                             mask: torch.Tensor,
                             hidden_state: Optional[RnnState] = None):
        """
        This function exists because Pytorch RNNs require that their inputs be sorted
        by decreasing length before being passed as input. As all of our Seq2xxxEncoders
        use this functionality, it is provided in a base class. This method can be called
        on any module which takes as input a ``PackedSequence`` and some ``hidden_state``,
        which can either be a tuple of tensors or a tensor.

        As all of our Seq2xxxEncoders have different return types, we return `sorted`
        outputs from the module, which is called directly. Additionally, we return the
        indices into the batch dimension required to restore the tensor to its correct,
        unsorted order and the number of valid batch elements (i.e. the number of elements
        in the batch which are not completely masked). This un-sorting and re-padding
        of the module outputs is left to the subclasses because their outputs have different
        types and handling them smoothly here is difficult.

        Parameters
        ----------
        module : ``Callable[[PackedSequence, Optional[RnnState]],
                            Tuple[Union[PackedSequence, torch.Tensor], RnnState]]``, required.
            A function to run on the inputs. In most cases, this is a ``torch.nn.Module``.
        inputs : ``torch.Tensor``, required.
            A tensor of shape ``(batch_size, sequence_length, embedding_size)`` representing
            the inputs to the Encoder.
        mask : ``torch.Tensor``, required.
            A tensor of shape ``(batch_size, sequence_length)``, representing masked and
            non-masked elements of the sequence for each element in the batch.
        hidden_state : ``Optional[RnnState]``, (default = None).
            A single tensor of shape (num_layers, batch_size, hidden_size) representing the
            state of an RNN, or a tuple of tensors of shapes
            (num_layers, batch_size, hidden_size) and (num_layers, batch_size, memory_size),
            representing the hidden state and memory state of an LSTM-like RNN.

        Returns
        -------
        module_output : ``Union[torch.Tensor, PackedSequence]``.
            A Tensor or PackedSequence representing the output of the Pytorch Module.
            The batch size dimension will be equal to ``num_valid``, as sequences of zero
            length are clipped off before the module is called, as Pytorch cannot handle
            zero length sequences.
        final_states : ``Optional[RnnState]``
            A Tensor representing the hidden state of the Pytorch Module. This can either
            be a single tensor of shape (num_layers, num_valid, hidden_size), for instance in
            the case of a GRU, or a tuple of tensors, such as those required for an LSTM.
        restoration_indices : ``torch.LongTensor``
            A tensor of shape ``(batch_size,)``, describing the re-indexing required to transform
            the outputs back to their original batch order.
        """
        # In some circumstances you may have sequences of zero length. ``pack_padded_sequence``
        # requires all sequence lengths to be > 0, so remove sequences of zero length before
        # calling the module, then fill with zeros.

        # First count how many sequences are empty.
        batch_size = mask.size(0)
        num_valid = torch.sum(mask[:, 0]).int().item()

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices =\
            sort_batch_by_length(inputs, sequence_lengths)

        # Now create a PackedSequence with only the non-empty, sorted sequences.
        packed_sequence_input = pack_padded_sequence(sorted_inputs[:num_valid, :, :],
                                                     sorted_sequence_lengths[:num_valid].data.tolist(),
                                                     batch_first=True)
        # Prepare the initial states.
        if not self.stateful:
            if hidden_state is None:
                initial_states = hidden_state
            elif isinstance(hidden_state, tuple):
                initial_states = [state.index_select(1, sorting_indices)[:, :num_valid, :].contiguous()
                                  for state in hidden_state]
            else:
                initial_states = hidden_state.index_select(1, sorting_indices)[:, :num_valid, :].contiguous()

        else:
            initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices)

        # Actually call the module on the sorted PackedSequence.
        module_output, final_states = module(packed_sequence_input, initial_states)

        return module_output, final_states, restoration_indices
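End to end, sort_and_run_forward implements the usual sort -> pack -> run -> restore dance that PyTorch RNNs require; a self-contained version with a plain LSTM (hypothetical sizes, and without the zero-length handling above):

import torch
from torch.nn import LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = LSTM(input_size=3, hidden_size=7, batch_first=True)
inputs = torch.rand(4, 6, 3)
lengths = torch.tensor([6, 3, 5, 2])

sorting_indices = lengths.argsort(descending=True)
restoration_indices = sorting_indices.argsort()

packed = pack_padded_sequence(inputs[sorting_indices],
                              lengths[sorting_indices].tolist(),
                              batch_first=True)
output, _ = lstm(packed)
unpacked, _ = pad_packed_sequence(output, batch_first=True)
unpacked = unpacked.index_select(0, restoration_indices)  # original batch order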
Example #45
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                pos_tags: Dict[str, torch.LongTensor] = None,
                span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
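The "negative start index marks padding" convention used for span_mask above appears in several of these models; a tiny illustration:

import torch

spans = torch.tensor([[[0, 1], [2, 3], [-1, -1]],
                      [[0, 2], [-1, -1], [-1, -1]]])  # (batch, num_spans, 2)
span_mask = (spans[:, :, 0] >= 0).long()
num_spans = span_mask.sum(-1)  # lengths of the non-padded span lists
print(span_mask)   # tensor([[1, 1, 0], [1, 0, 0]])
print(num_spans)   # tensor([2, 1])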
Example #46
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None,
                 prediction_map: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, sequence_length, num_classes).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class label of shape (batch_size, sequence_length). It must be the same
            shape as the ``predictions`` tensor without the ``num_classes`` dimension.
        mask: ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        prediction_map: ``torch.Tensor``, optional (default = None).
            A tensor of size (batch_size, num_classes) which provides a mapping from the index of predictions
            to the indices of the label vocabulary. If provided, the output label at each timestep will be
            ``vocabulary.get_index_to_token_vocabulary(prediction_map[batch, argmax(predictions[batch, t])])``,
            rather than simply ``vocabulary.get_index_to_token_vocabulary(argmax(predictions[batch, t]))``.
            This is useful in cases where each Instance in the dataset is associated with a different possible
            subset of labels from a large label-space (IE FrameNet, where each frame has a different set of
            possible roles associated with it).
        """
        if mask is None:
            mask = torch.ones_like(gold_labels)

        predictions, gold_labels, mask, prediction_map = self.unwrap_to_tensors(predictions,
                                                                                gold_labels,
                                                                                mask, prediction_map)

        num_classes = predictions.size(-1)
        if (gold_labels >= num_classes).any():
            raise ConfigurationError("A gold label passed to SpanBasedF1Measure contains an "
                                     "id >= {}, the number of classes.".format(num_classes))

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        argmax_predictions = predictions.max(-1)[1]

        if prediction_map is not None:
            argmax_predictions = torch.gather(prediction_map, 1, argmax_predictions)
            gold_labels = torch.gather(prediction_map, 1, gold_labels.long())

        argmax_predictions = argmax_predictions.float()

        # Iterate over timesteps in batch.
        batch_size = gold_labels.size(0)
        for i in range(batch_size):
            sequence_prediction = argmax_predictions[i, :]
            sequence_gold_label = gold_labels[i, :]
            length = sequence_lengths[i]

            if length == 0:
                # It is possible to call this metric with sequences which are
                # completely padded. These contribute nothing, so we skip these rows.
                continue

            predicted_string_labels = [self._label_vocabulary[label_id]
                                       for label_id in sequence_prediction[:length].tolist()]
            gold_string_labels = [self._label_vocabulary[label_id]
                                  for label_id in sequence_gold_label[:length].tolist()]

            tags_to_spans_function = None
            # `label_encoding` is empty and `tags_to_spans_function` is provided.
            if self._label_encoding is None and self._tags_to_spans_function:
                tags_to_spans_function = self._tags_to_spans_function
            # Search by `label_encoding`.
            elif self._label_encoding == "BIO":
                tags_to_spans_function = bio_tags_to_spans
            elif self._label_encoding == "IOB1":
                tags_to_spans_function = iob1_tags_to_spans
            elif self._label_encoding == "BIOUL":
                tags_to_spans_function = bioul_tags_to_spans
            elif self._label_encoding == "BMES":
                tags_to_spans_function = bmes_tags_to_spans

            predicted_spans = tags_to_spans_function(predicted_string_labels, self._ignore_classes)
            gold_spans = tags_to_spans_function(gold_string_labels, self._ignore_classes)

            predicted_spans = self._handle_continued_spans(predicted_spans)
            gold_spans = self._handle_continued_spans(gold_spans)

            for span in predicted_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            # These spans weren't predicted.
            for span in gold_spans:
                self._false_negatives[span[0]] += 1
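The tags_to_spans_function hook above converts string labels into (label, (start, end)) tuples; a simplified BIO-only version (ignoring continued spans and ignore_classes, so not a drop-in for bio_tags_to_spans):

from typing import List, Tuple

def simple_bio_tags_to_spans(tags: List[str]) -> List[Tuple[str, Tuple[int, int]]]:
    spans, start, label = [], None, None
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" flushes the final span
        # Anything except an I- tag continuing the open span closes that span.
        if tag == "O" or tag.startswith("B-") or tag[2:] != label:
            if label is not None:
                spans.append((label, (start, i - 1)))  # inclusive end index
            start, label = (i, tag[2:]) if tag != "O" else (None, None)
    return spans

print(simple_bio_tags_to_spans(["B-ARG0", "I-ARG0", "O", "B-V"]))
# [('ARG0', (0, 1)), ('V', (3, 3))]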
Example #47
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length size of the sequence_tensor.
            sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
            raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                             f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                             f"{sequence_lengths}.")

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            float_end_sentinel_mask = end_sentinel_mask.float()
            float_start_sentinel_mask = start_sentinel_mask.float()
            forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel
            backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                        + float_end_sentinel_mask * self._end_sentinel

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(self._forward_combination,
                                             [forward_start_embeddings, forward_end_embeddings])
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(self._backward_combination,
                                              [backward_start_embeddings, backward_end_embeddings])
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.float().unsqueeze(-1)
        return span_embeddings
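Underneath the exclusive/inclusive bookkeeping, batched_index_select amounts to a gather of whole vectors by position; a pared-down single-direction sketch:

import torch

sequence = torch.rand(2, 6, 4)       # (batch, seq_len, dim)
span_ends = torch.tensor([[2, 5],
                          [3, 3]])   # inclusive end indices, (batch, num_spans)

# Expand the indices to (batch, num_spans, dim) so gather selects full vectors.
index = span_ends.unsqueeze(-1).expand(-1, -1, sequence.size(-1))
end_embeddings = torch.gather(sequence, 1, index)  # (batch, num_spans, dim)
assert torch.equal(end_embeddings[0, 0], sequence[0, 2])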