    def batchify(self, batch, device):
        examples = list()
        sentence_len_s = [len(tup[1]) for tup in batch]
        sentence_len_t = [len(tup[2]) for tup in batch]

        max_sentence_len_s = max(sentence_len_s)
        max_sentence_len_t = max(sentence_len_t)

        # Event spans live at tuple indices 3 and 4 (see the loop below).
        event1_lens = [len(tup[3]) for tup in batch]
        event2_lens = [len(tup[4]) for tup in batch]

        sentences_s, sentences_t, event1, event2, data_y = [], [], [], [], []
        for data in batch:
            sentences_s.append(data[1])
            sentences_t.append(data[2])
            event1.append(data[3])
            event2.append(data[4])
            y = self.y_label[data[5]] if data[5] in self.y_label else 0
            data_y.append(y)
            examples.append(data)

        sentences_s = list(
            map(lambda x: pad_sequence_to_length(x, max_sentence_len_s),
                sentences_s))
        sentences_t = list(
            map(lambda x: pad_sequence_to_length(x, max_sentence_len_t),
                sentences_t))

        event1 = list(map(lambda x: pad_sequence_to_length(x, 5), event1))
        event2 = list(map(lambda x: pad_sequence_to_length(x, 5), event2))

        mask_sentences_s = get_mask_from_sequence_lengths(
            torch.LongTensor(sentence_len_s), max_sentence_len_s)
        mask_sentences_t = get_mask_from_sequence_lengths(
            torch.LongTensor(sentence_len_t), max_sentence_len_t)

        mask_even1 = get_mask_from_sequence_lengths(
            torch.LongTensor(event1_lens), 5)
        mask_even2 = get_mask_from_sequence_lengths(
            torch.LongTensor(event2_lens), 5)

        return [
            torch.LongTensor(sentences_s).to(device),
            mask_sentences_s.to(device),
            torch.LongTensor(sentences_t).to(device),
            mask_sentences_t.to(device),
            torch.LongTensor(event1).to(device),
            mask_even1.to(device),
            torch.LongTensor(event2).to(device),
            mask_even2.to(device),
            torch.LongTensor(data_y).to(device), examples
        ]
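A minimal standalone sketch of the padding/masking pattern used above, assuming AllenNLP's pad_sequence_to_length (allennlp.common.util) and get_mask_from_sequence_lengths (allennlp.nn.util); the token ids are hypothetical:

import torch
from allennlp.common.util import pad_sequence_to_length
from allennlp.nn.util import get_mask_from_sequence_lengths

# Two toy token-id sequences of different lengths.
sequences = [[4, 9, 2], [7, 1]]
lengths = torch.LongTensor([len(seq) for seq in sequences])
max_len = int(lengths.max())

# Pad every sequence to the batch maximum, then build the matching mask.
padded = torch.LongTensor(
    [pad_sequence_to_length(seq, max_len) for seq in sequences])
mask = get_mask_from_sequence_lengths(lengths, max_len)
# padded -> [[4, 9, 2], [7, 1, 0]]
# mask   -> [[1, 1, 1], [1, 1, 0]]  (booleans in AllenNLP >= 1.0, 0/1 longs in 0.x)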
Example #2
    def predict_crf(self,
                    hs,
                    ls=None,
                    lengths=None,
                    calculate_loss=True,
                    decode=False):
        device = hs.device
        if lengths is None:
            lengths = torch.tensor([h.shape[0] for h in hs], device=device)
        mask = get_mask_from_sequence_lengths(lengths, max_length=max(lengths))
        if not decode or self.crf_top_k == 1:
            ps = self.crf.viterbi_tags(hs, mask)
            ps, score = zip(*ps)
        else:
            ps = []
            psks = self.crf.viterbi_tags(hs, mask, top_k=self.crf_top_k)
            for psk in psks:
                psk, score = zip(*psk)
                ps.append(psk)

        if calculate_loss:
            log_likelihood = self.crf(hs, ls, mask)
            loss = -1 * log_likelihood / len(lengths)
        else:
            loss = torch.tensor(np.array(0), dtype=torch.float, device=device)

        return loss, ps
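For reference, a minimal hypothetical sketch of the same decode/loss pattern outside the class, assuming self.crf above is an instance of AllenNLP's ConditionalRandomField:

import torch
from allennlp.modules import ConditionalRandomField
from allennlp.nn.util import get_mask_from_sequence_lengths

crf = ConditionalRandomField(num_tags=3)
hs = torch.randn(2, 4, 3)                 # (batch, seq_len, num_tags) emission scores
ls = torch.zeros(2, 4, dtype=torch.long)  # dummy gold tag sequences
lengths = torch.tensor([4, 2])
mask = get_mask_from_sequence_lengths(lengths, int(lengths.max()))

ps, scores = zip(*crf.viterbi_tags(hs, mask))  # best tag sequence per instance
loss = -crf(hs, ls, mask) / len(lengths)       # negative mean log-likelihood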
Example #3
    def _encoding(self, word_inputs: torch.Tensor,
                  chars_inputs: torch.Tensor,
                  lengths: torch.Tensor,):
        # NOTE: there is no dropout on the last layer.
        start = time.time()
        embedded_tokens = self.token_embedder(word_inputs, chars_inputs)
        self.token_embedding_time += time.time() - start

        start = time.time()
        mask = get_mask_from_sequence_lengths(lengths, lengths.max())

        if self.add_sentence_boundary:
            embedded_tokens_with_boundary, mask_with_boundary = \
                self._add_sentence_boundary(embedded_tokens, mask)
            encoded_tokens = self.encoder(embedded_tokens_with_boundary,
                                          mask_with_boundary)

            self.encoding_time += time.time() - start
            return encoded_tokens[:, :, 1:-1, :], embedded_tokens, mask
        elif self.add_sentence_boundary_ids:

            encoded_tokens = self.encoder(embedded_tokens, mask)
            self.encoding_time += time.time() - start
            return self._remove_sentence_boundaries(encoded_tokens, embedded_tokens, mask)
        else:
            encoded_tokens = self.encoder(embedded_tokens, mask)
            self.encoding_time += time.time() - start
            return encoded_tokens, embedded_tokens, mask
Example #4
    def batchify(self, batch):
        cur_batch_size = len(batch)

        encode_sequence_ipt = []
        decode_sequence_ipt = []

        for instance_ind in range(cur_batch_size):
            instance = batch[instance_ind]
            encode_sequence_ipt.append(instance[:] + [self.word2idx['<END>']])
            decode_sequence_ipt.append([self.word2idx['<BOS>']] + instance[:])

        lens = [len(tup) for tup in encode_sequence_ipt]
        max_len = max(lens)

        encode_sequence_ipt = list(
            map(lambda x: pad_sequence_to_length(x, max_len),
                encode_sequence_ipt))
        decode_sequence_ipt = list(
            map(lambda x: pad_sequence_to_length(x, max_len),
                decode_sequence_ipt))
        mask = get_mask_from_sequence_lengths(torch.LongTensor(lens), max_len)

        encode_sequence_ipt = torch.LongTensor(encode_sequence_ipt).to(
            self.device)
        decode_sequence_ipt = torch.LongTensor(decode_sequence_ipt).to(
            self.device)
        mask = mask.to(self.device)

        return [encode_sequence_ipt, decode_sequence_ipt, mask]
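A toy illustration (with hypothetical vocabulary ids) of how the encoder and decoder inputs built above line up:

# Suppose <BOS> = 1, <END> = 2 and the instance is the id sequence [5, 6, 7].
instance = [5, 6, 7]
encode_sequence = instance + [2]  # [5, 6, 7, 2]  ends with <END>
decode_sequence = [1] + instance  # [1, 5, 6, 7]  starts with <BOS>
# Both are then padded to the batch maximum and masked with
# get_mask_from_sequence_lengths, exactly as in batchify above.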
Example #5
    def forward(self, inputs: Dict[str, torch.Tensor], targets: torch.Tensor):
        # input_: (batch_size, seq_len)
        embedded_input = {}
        for name, fn in self.input_layers.items():
            input_ = inputs[name]
            embedded_input[name] = fn(input_)

        encoded_inputs = []
        for encoder_ in self.input_encoders:
            ordered_names = encoder_.get_ordered_names()
            args_ = {name: embedded_input[name] for name in ordered_names}
            encoded_inputs.append(self.input_dropout(encoder_(args_)))

        encoded_inputs = torch.cat(encoded_inputs, dim=-1)

        lengths = inputs['length']
        mask = get_mask_from_sequence_lengths(lengths, lengths.max())

        encoded_inputs = self.encoder(encoded_inputs, mask)
        # encoded_input_: (batch_size, seq_len, dim)

        encoded_inputs = self.dropout(encoded_inputs)

        output, loss = self.classify_layer(encoded_inputs, targets)

        return output, loss
Example #6
    def test_get_mask_from_sequence_lengths(self):
        sequence_lengths = Variable(torch.LongTensor([4, 3, 1, 4, 2]))
        mask = util.get_mask_from_sequence_lengths(sequence_lengths,
                                                   5).data.numpy()
        assert_almost_equal(mask,
                            [[1, 1, 1, 1, 0], [1, 1, 1, 0, 0], [1, 0, 0, 0, 0],
                             [1, 1, 1, 1, 0], [1, 1, 0, 0, 0]])
Example #7
    def forward(self, word_inputs: torch.Tensor, char_inputs: torch.Tensor):
        embs = []
        if self.word_embedder is not None:
            word_inputs = torch.autograd.Variable(word_inputs,
                                                  requires_grad=False)
            embed_words = self.word_embedder(word_inputs)
            embs.append(embed_words)

        if self.char_embedder is not None:
            char_inputs, char_lengths = char_inputs
            batch_size, seq_len = char_lengths.size()[:2]
            char_inputs = char_inputs.view(batch_size * seq_len, -1)
            char_lengths = char_lengths.view(batch_size * seq_len)

            # (batch_size * seq_len, max_char, dim)
            embeded_chars = self.char_embedder(char_inputs)
            mask = get_mask_from_sequence_lengths(
                char_lengths, char_lengths.max()).unsqueeze(-1)
            float_mask = mask.float()
            embeded_chars = (embeded_chars * float_mask).sum(dim=-2)
            embs.append(embeded_chars.view(batch_size, seq_len, -1))

        token_embedding = torch.cat(embs, dim=2)

        return self.projection(token_embedding)
Example #8
    def forward(self, word_inputs: torch.Tensor, char_inputs: torch.Tensor):
        embs = []
        if self.word_embedder is not None:
            word_inputs = torch.autograd.Variable(word_inputs,
                                                  requires_grad=False)
            if self.use_cuda:
                word_inputs = word_inputs.cuda()
            word_emb = self.word_embedder(word_inputs)
            embs.append(word_emb)

        if self.char_embedder is not None:
            char_inputs, char_lengths = char_inputs
            batch_size, seq_len = char_lengths.size()

            char_inputs = char_inputs.view(batch_size * seq_len, -1)
            char_lengths = char_lengths.view(-1)
            char_mask = get_mask_from_sequence_lengths(char_lengths,
                                                       char_lengths.max())

            embeded_char_inputs = self.char_embedder(char_inputs)
            encoded_char_outputs, _ = self.char_encoder(embeded_char_inputs)
            char_attentions = masked_softmax(
                self.char_attention(encoded_char_outputs).squeeze(-1),
                char_mask,
                dim=-1)
            encoded_char_outputs = torch.bmm(
                encoded_char_outputs.permute(0, 2, 1),
                char_attentions.unsqueeze(-1))
            encoded_char_outputs = encoded_char_outputs.view(
                batch_size, seq_len, -1)
            embs.append(encoded_char_outputs)

        token_embedding = torch.cat(embs, dim=2)

        return self.projection(token_embedding)
Example #9
    def forward(self, question_and_answers: Dict[str, torch.LongTensor], video_features: Optional[torch.Tensor] = None,
                frame_count: Optional[torch.LongTensor] = None,
                label: Optional[torch.LongTensor] = None, **kwargs) -> Dict[str, torch.Tensor]:
        # This assumes a fixed number of answers, obtained by grabbing any of the available dict values.
        num_answers = list(question_and_answers.values())[0].shape[1]

        if video_features is None:
            video_features_mask = None
        else:
            video_features = self._expand_to_num_answers(video_features, num_answers)
            video_features_mask = self._expand_to_num_answers(
                util.get_mask_from_sequence_lengths(frame_count, video_features.shape[2]), num_answers)

        embedded_question_and_answers = self.text_field_embedder(question_and_answers, num_wrapping_dims=1)
        question_and_answers_mask = util.get_text_field_mask(question_and_answers, num_wrapping_dims=1)

        scores = self.answer_scorer(video_features=video_features,
                                    video_features_mask=video_features_mask,
                                    embedded_question_and_answers=embedded_question_and_answers,
                                    question_and_answers_mask=question_and_answers_mask)

        output_dict = {'scores': scores}

        if label is not None:
            output_dict['loss'] = self.loss(scores, label)
            for metric in self.metrics.values():
                metric(scores, label)

        return output_dict
Example #10
    def test_get_mask_from_sequence_lengths(self):
        sequence_lengths = torch.LongTensor([4, 3, 1, 4, 2])
        mask = util.get_mask_from_sequence_lengths(sequence_lengths, 5).data.numpy()
        assert_almost_equal(mask, [[1, 1, 1, 1, 0],
                                   [1, 1, 1, 0, 0],
                                   [1, 0, 0, 0, 0],
                                   [1, 1, 1, 1, 0],
                                   [1, 1, 0, 0, 0]])
Example #11
    def batchify(self, batch, device):
        sentence_len_s = [len(tup[0][1]) for tup in batch]

        max_sentence_len_s = self.max_len

        # Event spans live at indices 3 and 4 of each data tuple (see the loop below).
        event1_lens = [len(tup[0][3]) for tup in batch]
        event2_lens = [len(tup[0][4]) for tup in batch]

        sentences_s, sentences_s_mask, event1, event2, data_y = [], [], [], [], []
        for data, data_mask in batch:
            sentences_s.append(data[1])
            sentences_s_mask.append(data_mask[1])
            event1.append(data[3])
            event2.append(data[4])
            y = self.y_label[data[5]]
            data_y.append(y)

        sentences_s = list(
            map(lambda x: pad_sequence_to_length(x, max_sentence_len_s),
                sentences_s))
        sentences_s_mask = list(
            map(lambda x: pad_sequence_to_length(x, max_sentence_len_s),
                sentences_s_mask))

        event1 = list(map(lambda x: pad_sequence_to_length(x, 5), event1))
        event2 = list(map(lambda x: pad_sequence_to_length(x, 5), event2))

        mask_sentences_s = get_mask_from_sequence_lengths(
            torch.LongTensor(sentence_len_s), max_sentence_len_s)

        mask_even1 = get_mask_from_sequence_lengths(
            torch.LongTensor(event1_lens), 5)
        mask_even2 = get_mask_from_sequence_lengths(
            torch.LongTensor(event2_lens), 5)

        return [
            torch.LongTensor(sentences_s).to(device),
            mask_sentences_s.to(device),
            torch.LongTensor(sentences_s_mask).to(device),
            torch.LongTensor(event1).to(device),
            mask_even1.to(device),
            torch.LongTensor(event2).to(device),
            mask_even2.to(device),
            torch.LongTensor(data_y).to(device)
        ]
Example #12
    def _encode(self, source_features: torch.FloatTensor,
                source_lengths: torch.LongTensor) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        if self._cnn is not None:
            source_features, source_lengths = self._cnn(
                source_features, source_lengths)
        source_mask = util.get_mask_from_sequence_lengths(
            source_lengths, source_features.size(1))
        if self._conv_lstm is not None:
            source_features = self._conv_lstm(source_features, source_mask)
        if not isinstance(self._encoder, AWDRNN):
            encoder_outputs = self._encoder(source_features, source_mask)
        else:
            encoder_outputs, _, source_lengths = self._encoder(
                source_features, source_lengths, self._output_layer_num)
            source_mask = util.get_mask_from_sequence_lengths(
                source_lengths, encoder_outputs.size(1))
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}
Example #13
    def position_mask(self) -> torch.Tensor:
        """
        Which elements are actual words in the sentence?
        :return: shape (batch_size, input_seq_len)
        """
        if not hasattr(self, "lengths"):
            self.lengths = torch.tensor([len(s) + 1 for s in self.sentences],
                                        device=get_device_id(self.constants))
            self.max_len = max(len(s) for s in self.sentences) + 1

        return get_mask_from_sequence_lengths(self.lengths, self.max_len)
Example #14
    def _encode(self, source_features: torch.FloatTensor,
                source_lengths: torch.LongTensor) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        encoder_outputs, _, source_lengths = self._encoder(
            source_features, source_lengths)
        source_mask = util.get_mask_from_sequence_lengths(
            source_lengths, torch.max(source_lengths))
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }
Example #15
    def _get_phn_level_representations(
            self, features: torch.FloatTensor, mask: torch.BoolTensor,
            phn_log_probs: torch.Tensor) -> Dict[str, torch.Tensor]:
        phn_enc_outs, segment_lengths = averaging_tensor_of_same_label(
            features, phn_log_probs, mask=mask)
        state = {
            "encoder_outputs": phn_enc_outs,
            "source_mask": util.get_mask_from_sequence_lengths(
                segment_lengths, int(max(segment_lengths)))
        }
        return state
Example #16
    def _get_action_embeddings(
        state: NlvrDecoderState, actions_to_embed: List[List[int]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This method is identical to ``WikiTablesDecoderStep._get_action_embeddings``
        Returns an embedded representation for all actions in ``actions_to_embed``, using the state
        in ``NlvrDecoderState``.

        Parameters
        ----------
        state : ``NlvrDecoderState``
            The current state.  We'll use this to get the global action embeddings.
        actions_to_embed : ``List[List[int]]``
            A list of _global_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.

        Returns
        -------
        action_embeddings : ``torch.FloatTensor``
            An embedded representation of all of the given actions.  Shape is ``(group_size,
            num_actions, action_embedding_dim)``, where ``num_actions`` is the maximum number of
            considered actions for any group element.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        """
        num_actions = [len(action_list) for action_list in actions_to_embed]
        max_num_actions = max(num_actions)
        padded_actions = [
            common_util.pad_sequence_to_length(action_list, max_num_actions)
            for action_list in actions_to_embed
        ]
        # Shape: (group_size, num_actions)
        action_tensor = Variable(
            state.score[0].data.new(padded_actions).long())
        # `state.action_embeddings` is shape (total_num_actions, action_embedding_dim).
        # We want to select from state.action_embeddings using `action_tensor` to get a tensor of
        # shape (group_size, num_actions, action_embedding_dim).  Unfortunately, the index_select
        # functions in nn.util don't do this operation.  So we'll do some reshapes and do the
        # index_select ourselves.
        group_size = len(state.batch_indices)
        action_embedding_dim = state.action_embeddings.size(-1)
        flattened_actions = action_tensor.view(-1)
        flattened_action_embeddings = state.action_embeddings.index_select(
            0, flattened_actions)
        action_embeddings = flattened_action_embeddings.view(
            group_size, max_num_actions, action_embedding_dim)
        sequence_lengths = Variable(action_embeddings.data.new(num_actions))
        action_mask = nn_util.get_mask_from_sequence_lengths(
            sequence_lengths, max_num_actions)
        return action_embeddings, action_mask
Example #17
    def forward(self, input_: Tuple[torch.Tensor, torch.Tensor]):
        chars, lengths = input_
        batch_size, seq_len, max_chars = chars.size()

        chars = chars.view(batch_size * seq_len, -1)
        lengths = lengths.view(batch_size * seq_len)
        mask = get_mask_from_sequence_lengths(lengths, max_chars)
        chars = torch.autograd.Variable(chars, requires_grad=False)

        embeded_chars = self.embeddings(chars)
        output, _ = self.encoder_(embeded_chars)
        attentions = masked_softmax(self.attention(output).squeeze(-1), mask, dim=-1)
        output = torch.bmm(output.permute(0, 2, 1), attentions.unsqueeze(-1))

        return self.projection(output.view(batch_size, seq_len, -1))
Example #18
    def forward(self, input_: Tuple[torch.Tensor, torch.Tensor]):
        chars, lengths = input_
        batch_size, seq_len, max_chars = chars.size()

        chars = chars.view(batch_size * seq_len, -1)
        lengths = lengths.view(batch_size * seq_len)
        mask = get_mask_from_sequence_lengths(lengths, max_chars)
        chars = torch.autograd.Variable(chars, requires_grad=False)

        embeded_chars = self.embeddings(chars)
        output, _ = self.encoder_(embeded_chars)

        output = self.attention(output, mask).sum(dim=-2)

        return output.view(batch_size, seq_len, -1)
Example #19
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                context_span=None,
                gt_span=None,
                max_context_length=0,
                mode=ForwardMode.TRAIN):
        # Precomputing max_context_length is important because we want the same value
        # to be shared across different GPUs; computing it dynamically is not feasible.
        sequence_output, _ = self.bert_encoder(input_ids,
                                               token_type_ids,
                                               attention_mask,
                                               output_all_encoded_layers=False)

        joint_seq_logits = self.qa_outputs(sequence_output)
        context_logits, context_length = span_util.span_select(
            joint_seq_logits, context_span, max_context_length)
        context_mask = allen_util.get_mask_from_sequence_lengths(
            context_length, max_context_length)

        # The following line is from AllenNLP bidaf.
        start_logits = allen_util.replace_masked_values(
            context_logits[:, :, 0], context_mask, -1e18)
        # B, T, 2
        end_logits = allen_util.replace_masked_values(context_logits[:, :, 1],
                                                      context_mask, -1e18)

        if mode == BertSpan.ForwardMode.TRAIN:
            assert gt_span is not None
            gt_start = gt_span[:, 0]  # gt_span: [B, 2]
            gt_end = gt_span[:, 1]

            start_loss = nll_loss(
                allen_util.masked_log_softmax(start_logits, context_mask),
                gt_start.squeeze(-1))
            end_loss = nll_loss(
                allen_util.masked_log_softmax(end_logits, context_mask),
                gt_end.squeeze(-1))

            loss = start_loss + end_loss
            return loss
        else:
            return start_logits, end_logits, context_length
Example #20
    def _get_action_embeddings(state: NlvrDecoderState,
                               actions_to_embed: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This method is identical to ``WikiTablesDecoderStep._get_action_embeddings``
        Returns an embedded representation for all actions in ``actions_to_embed``, using the state
        in ``NlvrDecoderState``.

        Parameters
        ----------
        state : ``NlvrDecoderState``
            The current state.  We'll use this to get the global action embeddings.
        actions_to_embed : ``List[List[int]]``
            A list of _global_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.

        Returns
        -------
        action_embeddings : ``torch.FloatTensor``
            An embedded representation of all of the given actions.  Shape is ``(group_size,
            num_actions, action_embedding_dim)``, where ``num_actions`` is the maximum number of
            considered actions for any group element.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        """
        num_actions = [len(action_list) for action_list in actions_to_embed]
        max_num_actions = max(num_actions)
        padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                          for action_list in actions_to_embed]
        # Shape: (group_size, num_actions)
        action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)
        # `state.action_embeddings` is shape (total_num_actions, action_embedding_dim).
        # We want to select from state.action_embeddings using `action_tensor` to get a tensor of
        # shape (group_size, num_actions, action_embedding_dim).  Unfortunately, the index_select
        # functions in nn.util don't do this operation.  So we'll do some reshapes and do the
        # index_select ourselves.
        group_size = len(state.batch_indices)
        action_embedding_dim = state.action_embeddings.size(-1)
        flattened_actions = action_tensor.view(-1)
        flattened_action_embeddings = state.action_embeddings.index_select(0, flattened_actions)
        action_embeddings = flattened_action_embeddings.view(group_size, max_num_actions, action_embedding_dim)
        sequence_lengths = action_embeddings.new_tensor(num_actions)
        action_mask = nn_util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
        return action_embeddings, action_mask
Example #21
    def forward(
        self,
        word_embs: torch.Tensor,  # Float[Batch, Word, Embedding]
        mask: torch.Tensor,  # Byte[Batch, Word]
        left: bool = False,
    ) -> torch.Tensor:  # Float[Batch, Embedding]

        device = word_embs.device

        lengths = mask.long().sum(dim=1).cpu().numpy()  # Long[Batch]
        sorted_lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(
            -lengths)  # sort descendingly w.r.t. length of sequence
        idx_unsort = np.argsort(idx_sort)  # get inverse permutation

        x_sorted = word_embs.index_select(
            0,
            torch.from_numpy(idx_sort).to(
                device=device))  # Float[Batch, Word, Embedding]

        x_packed = torch.nn.utils.rnn.pack_padded_sequence(
            x_sorted, lengths=sorted_lengths.copy(), batch_first=True)

        y_output, _ = self.lstm(x_packed)
        y_output, _ = torch.nn.utils.rnn.pad_packed_sequence(y_output,
                                                             batch_first=True)

        y_unsorted = y_output.index_select(
            0,
            torch.from_numpy(idx_unsort).to(
                device=device))  # Float[Batch, Word, Encoding]

        y_unsorted_inf = torch.where(
            get_mask_from_sequence_lengths(
                torch.tensor(lengths).to(device=device),
                max_length=y_unsorted.size(1)).unsqueeze(dim=2).expand(
                    -1, -1, y_unsorted.size(2)), y_unsorted,
            torch.ones_like(y_unsorted) * float('-inf'))
        pooled, _ = torch.max(y_unsorted_inf, dim=1)

        output = self.final_dropout(pooled)
        if self.with_linear_transform and left:
            output = self.linear(output)

        return output
Example #22
    def test_average_tensor_of_same_labels(self):
        batch_size = 10
        max_len = 16
        feat_dim = 32
        label_dim = 4
        for _ in range(10):
            phn_logits = torch.randn(batch_size, max_len, label_dim)
            phn_log_probs = F.log_softmax(phn_logits, dim=-1)
            lengths = torch.randint(label_dim, (batch_size, ))
            mask = get_mask_from_sequence_lengths(lengths, int(max(lengths)))
            enc_outs = torch.randn((batch_size, max_len, feat_dim))
            _, max_ids = phn_log_probs.max(dim=-1)
            phn_enc_out_list = []
            for b in range(batch_size):
                count = 1
                phn_enc_out = []
                feat = enc_outs[b, 0].clone()
                prev_id = None
                for t, max_id in enumerate(max_ids[b]):
                    if prev_id is None:
                        pass
                    elif max_id == prev_id:
                        feat += enc_outs[b, t].clone()
                        count += 1
                    else:
                        phn_enc_out.append(feat.div(count))
                        feat = enc_outs[b, t].clone()
                        count = 1
                    prev_id = max_id
                phn_enc_out.append(feat / float(count))
                phn_enc_out_list.append(phn_enc_out)
            phn_max_len = len(max(phn_enc_out_list, key=lambda x: len(x)))
            phn_enc_outs = enc_outs.new_zeros(batch_size, phn_max_len,
                                              feat_dim)
            for idx, phn_enc_out in enumerate(phn_enc_out_list):
                phn_enc_outs[idx, :len(phn_enc_out)] = torch.stack(phn_enc_out)

            len_phn_enc_outs, _ = averaging_tensor_of_same_label(
                enc_outs, phn_log_probs, lengths)
            torch.testing.assert_allclose(phn_enc_outs, len_phn_enc_outs)
            mask_phn_enc_outs, _ = averaging_tensor_of_same_label(
                enc_outs, phn_log_probs, mask)
            torch.testing.assert_allclose(phn_enc_outs, mask_phn_enc_outs)
Example #23
def pad_contextualizer_output(seqs: List[torch.Tensor]):
    """
    Takes the output of a contextualizer, a list (of length batch_size)
    of Tensors with shape (seq_len, repr_dim), and produces a padded
    Tensor with these possibly-variable length items of shape
    (batch_size, seq_len, repr_dim)

    Returns
    -------
    padded_representations: torch.FloatTensor
        FloatTensor of shape (batch_size, seq_len, repr_dim) with 0 padding.
    mask: torch.Tensor
        A (batch_size, max_length) mask with 1s in positions without padding
        and 0s in positions with padding.
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    mask = get_mask_from_sequence_lengths(seqs[0].new_tensor(lengths), max_len)
    return torch.stack([
        torch.cat([s, s.new_zeros(max_len - len_, s.size(-1))], dim=0)
        for s, len_ in zip(seqs, lengths)
    ]), mask
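A hypothetical usage sketch, assuming torch is imported and the function above is in scope:

import torch

# Three contextualized sequences with 4, 2 and 3 tokens and an 8-dim representation each.
seqs = [torch.randn(4, 8), torch.randn(2, 8), torch.randn(3, 8)]
padded, mask = pad_contextualizer_output(seqs)
# padded.shape == (3, 4, 8); rows beyond each true length are zero padding.
# mask.shape == (3, 4), e.g. [[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]].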
Example #24
    def forward(
            self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            world: List[SpiderWorld],
            schema: Dict[str, torch.LongTensor],
            action_sequence: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:

        max_len_entities = max(
            [len(w.db_context.knowledge_graph.entities) for w in world])
        batch_size = len(world)
        device = utterance['tokens'].device

        oracle_entities = []
        oracle_relevance_score = None
        if action_sequence is not None:
            # we want oracle supervision for which entities should be in the query, for the loss calculation
            for batch_index, batch_actions in enumerate(
                    action_sequence.squeeze(-1)):
                oracle_entities.append(
                    set([
                        valid_actions[batch_index][action][0].split(
                            ' -> ')[1].strip('["]') for action in batch_actions
                        if not valid_actions[batch_index][action][1]
                        and action >= 0
                    ]))
            oracle_relevance_score = [
                pad_sequence_to_length(w.get_oracle_relevance_score(oe),
                                       max_len_entities)
                for w, oe in zip(world, oracle_entities)
            ]
            oracle_relevance_score = torch.tensor(oracle_relevance_score,
                                                  dtype=torch.float,
                                                  device=device)

        initial_state = self._get_initial_state(utterance, world, schema,
                                                valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            action_mask = action_sequence != self._action_padding_index
        else:
            action_mask = None

        self.graph_mask = util.get_mask_from_sequence_lengths(
            torch.tensor([len(w.entities_names) for w in world],
                         device=device), max_len_entities).float()

        loss = torch.tensor([0]).float().to(device)

        if action_sequence is not None:
            graph_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                self.predicted_relevance_logits.squeeze(-1),
                oracle_relevance_score,
                reduction='none')
            graph_loss = (graph_loss *
                          self.graph_mask).sum() / self.graph_mask.sum()

            graph_loss *= self._graph_loss_lambda

            loss += graph_loss

        if self.training:
            try:
                decode_output = self._decoder_trainer.decode(
                    initial_state, self._transition_function,
                    (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))
                query_loss = decode_output['loss']
            except ZeroDivisionError:
                return {
                    'loss':
                    Parameter(torch.tensor([0]).float()).to(
                        action_sequence.device)
                }

            loss += ((1 - self._graph_loss_lambda) * query_loss)

            return {'loss': loss}
        else:
            if action_sequence is not None and action_sequence.size(1) > 1:
                try:
                    query_loss = self._decoder_trainer.decode(
                        initial_state, self._transition_function,
                        (action_sequence.unsqueeze(1),
                         action_mask.unsqueeze(1)))['loss']
                    loss += query_loss
                except ZeroDivisionError:
                    pass

            outputs: Dict[str, Any] = {'loss': loss}

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]

            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False)

            self._compute_validation_outputs(valid_actions, best_final_states,
                                             world, action_sequence, outputs)
            return outputs
Example #25
def span_pruner(embeddings,
                scores,
                mask,
                seq_length,
                spans_per_word=1,
                num_keep=None):
    """

        Based on AllenNLP allennlp.modules.Pruner from release 0.84


        Parameters
        ----------

        logits: (batch_size, num_spans, num_tags)
        mask: (batch_size, num_spans)
        num_keep: int OR torch.LongTensor
                If a tensor of shape (batch_size), specifies the
                number of items to keep for each
                individual sentence in minibatch.
                If an int, keep the same number of items for all sentences.


        """

    #batch_size, num_items, num_tags = tuple(logits.shape)
    batch_size, num_items = tuple(scores.shape)

    # Number to keep not provided, so use spans per word
    if num_keep is None:
        num_keep = seq_length * spans_per_word
        num_keep = torch.max(num_keep, torch.ones_like(num_keep))

    # If an int was given for number of items to keep, construct tensor by repeating the value.
    if isinstance(num_keep, int):
        num_keep = num_keep * torch.ones(
            [batch_size], dtype=torch.long, device=mask.device)

    # Maximum number to keep
    max_keep = num_keep.max()

    # Get scores from logits
    # (batch_size, num_spans)
    # scores = logit_scorer(logits)

    # Set overlapping span scores large neg number
    #if prune_overlapping:
    #    scores = overlap_filter(scores, span_overlaps)

    # Add dimension
    scores = scores.unsqueeze(-1)
    #embeddings = embeddings.unsqueeze(-1)

    # Check scores dimensionality
    if scores.size(-1) != 1 or scores.dim() != 3:
        raise ValueError(
            f"The scorer passed to Pruner must produce a tensor of shape "
            f"(batch_size, num_items, 1), but found shape {scores.size()}")

    # Make sure that we don't select any masked items by setting their scores to be very
    # negative.  These are logits, typically, so -1e20 should be plenty negative.
    #print("scores", scores.shape)
    #print('mask', mask.shape)
    mask = mask.unsqueeze(-1).bool()  #type(torch.BoolTensor)
    #print('mask', mask.shape, mask.type)
    scores = util.replace_masked_values(scores, mask, NEG_FILL)

    # Shape: (batch_size, max_num_items_to_keep, 1)
    _, top_indices = scores.topk(max_keep, 1)

    # Mask based on number of items to keep for each sentence.
    # Shape: (batch_size, max_num_items_to_keep)
    top_indices_mask = util.get_mask_from_sequence_lengths(num_keep, max_keep)
    top_indices_mask = top_indices_mask.bool()

    # Shape: (batch_size, max_num_items_to_keep)
    top_indices = top_indices.squeeze(-1)

    # Fill all masked indices with largest "top" index for that sentence, so that all masked
    # indices will be sorted to the end.
    # Shape: (batch_size, 1)
    fill_value, _ = top_indices.max(dim=1)
    fill_value = fill_value.unsqueeze(-1)
    # Shape: (batch_size, max_num_items_to_keep)
    top_indices = torch.where(top_indices_mask, top_indices, fill_value)
    # Now we order the selected indices in increasing order with
    # respect to their indices (and hence, with respect to the
    # order they originally appeared in the ``embeddings`` tensor).
    top_indices, _ = torch.sort(top_indices, 1)

    # Shape: (batch_size * max_num_items_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select items for each element in the batch.
    flat_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

    # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
    # the top k for each sentence.
    # Shape: (batch_size, max_num_items_to_keep)
    sequence_mask = util.batched_index_select(mask, top_indices, flat_indices)
    sequence_mask = sequence_mask.squeeze(-1).bool()
    top_mask = top_indices_mask & sequence_mask
    top_mask = top_mask.long()

    # Shape: (batch_size, max_num_items_to_keep, 1)
    top_scores = util.batched_index_select(scores, top_indices, flat_indices)
    top_embeddings = util.batched_index_select(embeddings, top_indices,
                                               flat_indices)

    # Shape: (batch_size, max_num_items_to_keep)
    top_scores = top_scores.squeeze(-1)
    #top_embeddings = top_embeddings.squeeze(-1)

    return (top_indices, top_embeddings, top_scores, top_mask)
Example #26
    def forward(
        self,
        sentence: torch.Tensor,  # R[Batch, Word, Emb]
        sentence_lengths: torch.Tensor,  # Z_Word[Batch]
        span: torch.Tensor,  # R[Batch, Word, Emb]
        span_lengths: torch.Tensor,  # Z_Word[Batch]
        span_left: torch.Tensor,  # Z_Word[Batch]
        span_right: torch.Tensor  # Z_Word[Batch]
    ) -> torch.Tensor:  # R[Batch, Feature]

        batch_size = sentence.size(0)
        sentence_max_len = sentence.size(1)
        emb_size = sentence.size(2)
        span_max_len = span.size(1)
        device = sentence.device
        neg_inf = torch.tensor(-10000, dtype=torch.float32, device=device)
        zero = torch.tensor(0, dtype=torch.float32, device=device)

        span = self.projection(self.dropout(span))
        sentence = self.projection(self.dropout(sentence))

        span_mask = get_mask_from_sequence_lengths(
            span_lengths,
            span_lengths.max().item()).bool()  # B[Batch, Word]

        def attention_pool():
            span_attn_scores = torch.einsum('e,bwe->bw', self.query, span)
            masked_span_attn_scores = torch.where(span_mask, span_attn_scores,
                                                  neg_inf)
            normalized_span_attn_scores = F.softmax(masked_span_attn_scores,
                                                    dim=1)
            span_pooled = torch.einsum('bwe,bw->be', span,
                                       normalized_span_attn_scores)
            return span_pooled

        span_pooled = {
            "max":
            lambda: torch.max(torch.where(
                span_mask.unsqueeze(dim=2).expand_as(span), span, neg_inf),
                              dim=1)[0],
            "mean":
            lambda: torch.sum(torch.where(
                span_mask.unsqueeze(dim=2).expand_as(span), span, zero),
                              dim=1) / span_lengths.unsqueeze(dim=1).expand(
                                  batch_size, emb_size),
            "attention":
            lambda: attention_pool()
        }[self.mention_pooling]()  # R[Batch, Emb]

        features = span_pooled

        if self.with_context:
            sentence_mask = get_mask_from_sequence_lengths(
                sentence_lengths, sentence_max_len).bool()  # B[B, L]

            length_range = torch.arange(0, sentence_max_len, device=device) \
                .unsqueeze(dim=0).expand(batch_size, sentence_max_len)
            span_mask = (length_range >= (span_left.unsqueeze(dim=1).expand_as(length_range))) \
                & (length_range < (span_right.unsqueeze(dim=1).expand_as(length_range)))  # B[Batch, Length]

            span_queries = self.mention_query_transform(span_pooled)
            attn_scores = torch.einsum('be,bwe->bw', span_queries,
                                       sentence)  # R[Batch, Word]
            masked_attn_scores = torch.where(
                sentence_mask, attn_scores,
                neg_inf)  # R[Batch, Word]  & ~span_mask
            normalized_attn_scores = F.softmax(masked_attn_scores, dim=1)
            context_pooled = torch.einsum(
                'bwe,bw->be', sentence,
                normalized_attn_scores)  # R[Batch, Emb]

            features = torch.cat([span_pooled, context_pooled],
                                 dim=1)  # R[Batch, Emb*2]

        return features  # R[Batch, Emb]
Example #27
    def forward(
            self,
            context_ids: TextFieldTensors,
            query_ids: TextFieldTensors,
            context_lens: torch.Tensor,
            query_lens: torch.Tensor,
            mask_label: Optional[torch.Tensor] = None,
            cls_label: Optional[torch.Tensor] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # concat the context and query to the encoder
        # get the indexers first
        indexers = context_ids.keys()
        dialogue_ids = {}

        # get the lengths of the context and the query
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()

        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
        for indexer in indexers:
            # get the various variables of context and query
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                # concat the context and query in the length dim
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # get the outputs of the dialogue
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # get the outputs of the query and context
        # [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        output_dict = {}
        # --------- cls task: decide whether the query needs rewriting ------------------
        if self._cls_task:
            # get the cls representation, [B, embed_size]
            cls_embed = context_last_layer[:, 0]
            # classify with a linear layer, [B, 2]
            cls_logits = self._cls_linear(cls_embed)
            output_dict["cls_logits"] = cls_logits
        else:
            cls_logits = None

        # --------- mask task: find the positions in the query that need filling -----------
        if self._mask_task:
            # pass through a linear layer, [B, _len, 2]
            mask_logits = self._mask_linear(query_last_layer)
            output_dict["mask_logits"] = mask_logits
        else:
            mask_logits = None

        if cls_label is not None:
            output_dict["loss"] = self._calc_loss(cls_label, mask_label,
                                                  cls_logits, mask_logits,
                                                  query_mask)

        return output_dict
Example #28
    def _get_action_embeddings(state: WikiTablesDecoderState,
                               actions_to_embed: List[List[int]]) -> Tuple[torch.Tensor,
                                                                           torch.Tensor,
                                                                           torch.Tensor,
                                                                           torch.Tensor]:
        """
        Returns an embedded representation for all actions in ``actions_to_embed``, using the state
        in ``WikiTablesDecoderState``.

        Parameters
        ----------
        state : ``WikiTablesDecoderState``
            The current state.  We'll use this to get the global action embeddings.
        actions_to_embed : ``List[List[int]]``
            A list of _global_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.  This is expected to be output from
            :func:`_get_actions_to_consider`.

        Returns
        -------
        action_embeddings : ``torch.FloatTensor``
            An embedded representation of all of the given actions.  Shape is ``(group_size,
            num_actions, action_embedding_dim)``, where ``num_actions`` is the maximum number of
            considered actions for any group element.
        output_action_embeddings : ``torch.FloatTensor``
            A second embedded representation of all of the given actions.  The first is used when
            selecting actions, the second is used as the decoder output (which is the input at the
            next timestep).  This is similar to having separate word embeddings and softmax layer
            weights in a language model or MT model.
        action_biases : ``torch.FloatTensor``
            A bias weight for predicting each action.  Shape is ``(group_size, num_actions, 1)``.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        """
        num_actions = [len(action_list) for action_list in actions_to_embed]
        max_num_actions = max(num_actions)
        padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                          for action_list in actions_to_embed]
        # Shape: (group_size, num_actions)
        action_tensor = Variable(state.score[0].data.new(padded_actions).long())
        # `state.action_embeddings` is shape (total_num_actions, action_embedding_dim).
        # We want to select from state.action_embeddings using `action_tensor` to get a tensor of
        # shape (group_size, num_actions, action_embedding_dim).  Unfortunately, the index_select
        # functions in nn.util don't do this operation.  So we'll do some reshapes and do the
        # index_select ourselves.
        group_size = len(state.batch_indices)
        action_embedding_dim = state.action_embeddings.size(-1)

        flattened_actions = action_tensor.view(-1)
        flattened_action_embeddings = state.action_embeddings.index_select(0, flattened_actions)
        action_embeddings = flattened_action_embeddings.view(group_size, max_num_actions, action_embedding_dim)

        flattened_output_embeddings = state.output_action_embeddings.index_select(0, flattened_actions)
        output_embeddings = flattened_output_embeddings.view(group_size, max_num_actions, action_embedding_dim)

        flattened_biases = state.action_biases.index_select(0, flattened_actions)
        biases = flattened_biases.view(group_size, max_num_actions, 1)

        sequence_lengths = Variable(action_embeddings.data.new(num_actions))
        action_mask = util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
        return action_embeddings, output_embeddings, biases, action_mask
Example #29
    def _get_entity_action_logits(self,
                                  state: WikiTablesDecoderState,
                                  actions_to_link: List[List[int]],
                                  attention_weights: torch.Tensor,
                                  linked_checklist_balance: torch.Tensor = None) -> Tuple[torch.FloatTensor,
                                                                                          torch.LongTensor,
                                                                                          torch.FloatTensor]:
        """
        Returns scores for each action in ``actions_to_link`` that are derived from the linking
        scores between the question and the table entities, and the current attention on the
        question.  The intuition is that if we're paying attention to a particular word in the
        question, we should tend to select entity productions that we think that word refers to.
        We additionally return a mask representing which elements in the returned ``action_logits``
        tensor are just padding, and an embedded representation of each action that can be used as
        input to the next step of the encoder.  That embedded representation is derived from the
        type of the entity produced by the action.

        The ``actions_to_link`` are in terms of the `batch` action list passed to
        ``model.forward()``.  We need to convert these integers into indices into the linking score
        tensor, which has shape (batch_size, num_entities, num_question_tokens), look up the
        linking score for each entity, then aggregate the scores using the current question
        attention.

        Parameters
        ----------
        state : ``WikiTablesDecoderState``
            The current state.  We'll use this to get the linking scores.
        actions_to_link : ``List[List[int]]``
            A list of _batch_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.  This is expected to be output from
            :func:`_get_actions_to_consider`.
        attention_weights : ``torch.Tensor``
            The current attention weights over the question tokens.  Should have shape
            ``(group_size, num_question_tokens)``.
        linked_checklist_balance : ``torch.Tensor``, optional (default=None)
            If the parser is being trained to maximize coverage over an agenda, this is the balance
            vector corresponding to entity actions, containing 1s and 0s, with 1s showing the
            actions that are yet to be produced. Required only if the parser is being trained to
            maximize coverage.

        Returns
        -------
        action_logits : ``torch.FloatTensor``
            A score for each of the given actions.  Shape is ``(group_size, num_actions)``, where
            ``num_actions`` is the maximum number of considered actions for any group element.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        type_embeddings : ``torch.LongTensor``
            A tensor of shape ``(group_size, num_actions, action_embedding_dim)``, with an embedded
            representation of the `type` of the entity corresponding to each action.
        """
        # First we map the actions to entity indices, using state.actions_to_entities, and find the
        # type of each entity using state.entity_types.
        action_entities: List[List[int]] = []
        entity_types: List[List[int]] = []
        for batch_index, action_list in zip(state.batch_indices, actions_to_link):
            action_entities.append([])
            entity_types.append([])
            for action_index in action_list:
                entity_index = state.actions_to_entities[(batch_index, action_index)]
                action_entities[-1].append(entity_index)
                entity_types[-1].append(state.entity_types[entity_index])

        # Then we create a padded tensor suitable for use with
        # `state.flattened_linking_scores.index_select()`.
        num_actions = [len(action_list) for action_list in action_entities]
        max_num_actions = max(num_actions)
        padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                          for action_list in action_entities]
        padded_types = [common_util.pad_sequence_to_length(type_list, max_num_actions)
                        for type_list in entity_types]
        # Shape: (group_size, num_actions)
        action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)
        type_tensor = state.score[0].new_tensor(padded_types, dtype=torch.long)

        # To get the type embedding tensor, we just use an embedding matrix on the list of entity
        # types.
        type_embeddings = self._entity_type_embedding(type_tensor)
        # `state.flattened_linking_scores` is shape (batch_size * num_entities, num_question_tokens).
        # We want to select from this using `action_tensor` to get a tensor of shape (group_size,
        # num_actions, num_question_tokens).  Unfortunately, the index_select functions in nn.util
        # don't do this operation.  So we'll do some reshapes and do the index_select ourselves.
        group_size = len(state.batch_indices)
        num_question_tokens = state.flattened_linking_scores.size(-1)
        flattened_actions = action_tensor.view(-1)
        # (group_size * num_actions, num_question_tokens)
        flattened_action_linking = state.flattened_linking_scores.index_select(0, flattened_actions)
        # (group_size, num_actions, num_question_tokens)
        action_linking = flattened_action_linking.view(group_size, max_num_actions, num_question_tokens)

        # Now we get action logits by weighting these entity x token scores by the attention over
        # the question tokens.  We can do this efficiently with torch.bmm.
        action_logits = action_linking.bmm(attention_weights.unsqueeze(-1)).squeeze(-1)
        if linked_checklist_balance is not None:
            # ``linked_checklist_balance`` is a binary tensor of size (group_size, num_actions) with
            # 1s indicating the linked actions that the agenda wants the decoder to produce, but
            # haven't been produced yet. We're simply doubling the logits of those actions here.
            action_logits_addition = action_logits * linked_checklist_balance
            action_logits = action_logits + self._linked_checklist_multiplier * action_logits_addition
        # Finally, we make a mask for our action logit tensor.
        sequence_lengths = action_linking.new_tensor(num_actions)
        action_mask = util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
        return action_logits, action_mask, type_embeddings
Example #30
    def _get_action_embeddings(state: WikiTablesDecoderState,
                               actions_to_embed: List[List[int]]) -> Tuple[torch.Tensor,
                                                                           torch.Tensor,
                                                                           torch.Tensor,
                                                                           torch.Tensor]:
        """
        Returns an embedded representation for all actions in ``actions_to_embed``, using the state
        in ``WikiTablesDecoderState``.

        Parameters
        ----------
        state : ``WikiTablesDecoderState``
            The current state.  We'll use this to get the global action embeddings.
        actions_to_embed : ``List[List[int]]``
            A list of _global_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.  This is expected to be output from
            :func:`_get_actions_to_consider`.

        Returns
        -------
        action_embeddings : ``torch.FloatTensor``
            An embedded representation of all of the given actions.  Shape is ``(group_size,
            num_actions, action_embedding_dim)``, where ``num_actions`` is the maximum number of
            considered actions for any group element.
        output_action_embeddings : ``torch.FloatTensor``
            A second embedded representation of all of the given actions.  The first is used when
            selecting actions, the second is used as the decoder output (which is the input at the
            next timestep).  This is similar to having separate word embeddings and softmax layer
            weights in a language model or MT model.
        action_biases : ``torch.FloatTensor``
            A bias weight for predicting each action.  Shape is ``(group_size, num_actions, 1)``.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        """
        num_actions = [len(action_list) for action_list in actions_to_embed]
        max_num_actions = max(num_actions)
        padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                          for action_list in actions_to_embed]
        # Shape: (group_size, num_actions)
        action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)
        # `state.action_embeddings` is shape (total_num_actions, action_embedding_dim).
        # We want to select from state.action_embeddings using `action_tensor` to get a tensor of
        # shape (group_size, num_actions, action_embedding_dim).  Unfortunately, the index_select
        # functions in nn.util don't do this operation.  So we'll do some reshapes and do the
        # index_select ourselves.
        group_size = len(state.batch_indices)
        action_embedding_dim = state.action_embeddings.size(-1)

        flattened_actions = action_tensor.view(-1)
        flattened_action_embeddings = state.action_embeddings.index_select(0, flattened_actions)
        action_embeddings = flattened_action_embeddings.view(group_size, max_num_actions, action_embedding_dim)

        flattened_output_embeddings = state.output_action_embeddings.index_select(0, flattened_actions)
        output_embeddings = flattened_output_embeddings.view(group_size, max_num_actions, action_embedding_dim)

        flattened_biases = state.action_biases.index_select(0, flattened_actions)
        biases = flattened_biases.view(group_size, max_num_actions, 1)

        sequence_lengths = action_embeddings.new_tensor(num_actions)
        action_mask = util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
        return action_embeddings, output_embeddings, biases, action_mask
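The padding-and-mask bookkeeping that the action-scoring methods above rely on can be sketched on its own; the action lists below are made up, and pad_sequence_to_length / get_mask_from_sequence_lengths are the AllenNLP helpers already used in these snippets:

import torch
from allennlp.common.util import pad_sequence_to_length
from allennlp.nn.util import get_mask_from_sequence_lengths

# Made-up, unpadded action-index lists for a group of three decoder states.
actions_to_embed = [[0, 3], [1, 4, 7], [2]]
num_actions = [len(action_list) for action_list in actions_to_embed]
max_num_actions = max(num_actions)

padded_actions = [pad_sequence_to_length(action_list, max_num_actions)
                  for action_list in actions_to_embed]
action_tensor = torch.tensor(padded_actions, dtype=torch.long)  # (group_size, max_num_actions)

# Non-zero entries mark real actions; padded positions are masked out.
action_mask = get_mask_from_sequence_lengths(torch.tensor(num_actions), max_num_actions)
print(action_tensor)
print(action_mask)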
Example No. 31
    def forward(self,
                inputs,
                mask,
                sent_counts,
                sent_lens,
                prompt_inputs,
                prompt_mask,
                prompt_sent_counts,
                prompt_sent_lens,
                manual_feature,
                label=None):
        """

        :param prompt_sent_lens:
        :param prompt_sent_counts:
        :param prompt_inputs:
        :param prompt_mask:
        :param inputs:  [batch size, max sent count, max sent len]
        :param mask:    [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param label: [batch size]
        :return:
        """
        batch_size = inputs.shape[0]
        max_sent_count = inputs.shape[1]
        max_sent_length = inputs.shape[2]

        inputs = inputs.view(-1, inputs.shape[-1])
        mask = mask.view(-1, mask.shape[-1])

        # [batch size * max sent count, max sent len, hid size]
        last_hidden_states = self.bert(input_ids=inputs,
                                       attention_mask=mask)[0]
        last_hidden_states = last_hidden_states.view(batch_size,
                                                     max_sent_count,
                                                     max_sent_length, -1)

        prompt_inputs = prompt_inputs.view(-1, prompt_inputs.shape[-1])
        prompt_mask = prompt_mask.view(-1, prompt_mask.shape[-1])
        prompt_hidden_states = self.bert(input_ids=prompt_inputs,
                                         attention_mask=prompt_mask)[0]

        docs = []
        lens = []
        for i in range(0, batch_size):
            doc = []
            sent_count = sent_counts[i]
            sent_len = sent_lens[i]

            for j in range(sent_count):
                length = sent_len[j]
                cur_sent = last_hidden_states[i, j, :length, :]
                # print('cur sent shape', cur_sent.shape)
                doc.append(cur_sent)

            # concatenate the sentence token states into one sequence for the doc
            doc_vec = torch.cat(doc, dim=0).unsqueeze(0)
            doc_vec = self.positional_encoding.forward(doc_vec)

            lens.append(doc_vec.shape[1])
            # print(i, 'doc shape', doc_vec.shape)
            docs.append(doc_vec)

        batch_max_len = max(lens)
        for i, doc in enumerate(docs):
            if doc.shape[1] < batch_max_len:
                pd = (0, 0, 0, batch_max_len - doc.shape[1])
                m = nn.ConstantPad2d(pd, 0)
                doc = m(doc)

            docs[i] = doc

        # [batch size, batch max len, bert hidden size]
        docs = torch.cat(docs, 0)
        docs_mask = get_mask_from_sequence_lengths(
            torch.tensor(lens), max_length=batch_max_len).to(docs.device)

        prompt = []
        for j in range(prompt_sent_counts):
            length = prompt_sent_lens[0][j]
            sent = prompt_hidden_states[j, :length, :]
            prompt.append(sent)

        prompt_vec = torch.cat(prompt, dim=0).unsqueeze(0)
        prompt_vec = self.positional_encoding.forward(prompt_vec)
        prompt_len = prompt_vec.shape[1]
        prompt_attention_mask = get_mask_from_sequence_lengths(
            torch.tensor([prompt_len]),
            max_length=prompt_len).to(prompt_vec.device)
        # [1, seq len]
        prompt_vec_weights = self.prompt_global_attention(
            prompt_vec, prompt_attention_mask)
        # [1, bert hidden size]
        prompt_vec = torch.bmm(prompt_vec_weights.unsqueeze(1),
                               prompt_vec).squeeze(1)

        doc_weights = self.doc_global_attention(docs, docs_mask)
        doc_vec = torch.bmm(doc_weights.unsqueeze(1), docs).squeeze(1)

        doc_feature = self.dropout_layer(torch.tanh(doc_vec))
        prompt_feature = self.dropout_layer(
            torch.tanh(prompt_vec.expand_as(doc_feature)))
        feature = torch.cat([doc_feature, prompt_feature], dim=-1)

        log_probs = torch.log_softmax(self.linear_layer(feature), dim=-1)

        # log_probs = self.classifier(docs)
        if label is not None:
            loss = self.criterion(input=log_probs.contiguous().view(
                -1, log_probs.shape[-1]),
                                  target=label.contiguous().view(-1))
        else:
            loss = None

        prediction = torch.max(log_probs, dim=1)[1]
        return {'loss': loss, 'prediction': prediction}
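The doc_global_attention / prompt_global_attention modules are not defined in this snippet; a minimal stand-in for that kind of masked global-attention pooling (an assumption about its interface, not the author's implementation) might look like:

import torch
import torch.nn as nn

class GlobalAttentionPooling(nn.Module):
    """Minimal stand-in for the global attention used above (hypothetical internals)."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.scorer = nn.Linear(hidden_size, 1)

    def forward(self, states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # states: (batch, seq_len, hidden), mask: (batch, seq_len) with 1/True for real tokens.
        scores = self.scorer(states).squeeze(-1)
        scores = scores.masked_fill(~mask.bool(), float("-inf"))
        return torch.softmax(scores, dim=-1)  # (batch, seq_len) attention weights

# Usage mirrors the snippet above: weights are combined with the states via bmm.
pool = GlobalAttentionPooling(hidden_size=8)
states = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
weights = pool(states, mask)
doc_vec = torch.bmm(weights.unsqueeze(1), states).squeeze(1)  # (batch, hidden)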
Example No. 32
    def forward(
        self,  # pylint: disable=arguments-differ
        embeddings: torch.FloatTensor,
        mask: torch.LongTensor,
        num_items_to_keep: Union[int, torch.LongTensor]
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor,
               torch.FloatTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model). May use the same k for all sentences in
        minibatch, or different k for each.

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``Union[int, torch.LongTensor]``, required.
            If a tensor of shape (batch_size), specifies the number of items to keep for each
            individual sentence in minibatch.
            If an int, keep the same number of items for all sentences.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, max_num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, max_num_items_to_keep).
        top_indices : ``torch.IntTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, max_num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, max_num_items_to_keep, 1).
        """
        # If an int was given for number of items to keep, construct tensor by repeating the value.
        if isinstance(num_items_to_keep, int):
            batch_size = mask.size(0)
            # Put the tensor on same device as the mask.
            num_items_to_keep = num_items_to_keep * torch.ones(
                [batch_size], dtype=torch.long, device=mask.device)

        max_items_to_keep = num_items_to_keep.max()

        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)
        # Shape: (batch_size, num_items, 1)
        scores = self._scorer(embeddings)

        if scores.size(-1) != 1 or scores.dim() != 3:
            raise ValueError(
                f"The scorer passed to Pruner must produce a tensor of shape"
                f"(batch_size, num_items, 1), but found shape {scores.size()}")
        # Make sure that we don't select any masked items by setting their scores to be very
        # negative.  These are logits, typically, so -1e20 should be plenty negative.
        scores = util.replace_masked_values(scores, mask, -1e20)

        # Shape: (batch_size, max_num_items_to_keep, 1)
        _, top_indices = scores.topk(max_items_to_keep, 1)

        # Mask based on number of items to keep for each sentence.
        # Shape: (batch_size, max_num_items_to_keep)
        top_indices_mask = util.get_mask_from_sequence_lengths(
            num_items_to_keep, max_items_to_keep)
        top_indices_mask = top_indices_mask.byte()

        # Shape: (batch_size, max_num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Fill all masked indices with largest "top" index for that sentence, so that all masked
        # indices will be sorted to the end.
        # Shape: (batch_size, 1)
        fill_value, _ = top_indices.max(dim=1)
        fill_value = fill_value.unsqueeze(-1)
        # Shape: (batch_size, max_num_items_to_keep)
        top_indices = torch.where(top_indices_mask, top_indices, fill_value)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size * max_num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(
            top_indices, num_items)

        # Shape: (batch_size, max_num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices,
                                                   flat_top_indices)

        # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
        # the top k for each sentence.
        # Shape: (batch_size, max_num_items_to_keep)
        sequence_mask = util.batched_index_select(mask, top_indices,
                                                  flat_top_indices)
        sequence_mask = sequence_mask.squeeze(-1).byte()
        top_mask = top_indices_mask & sequence_mask
        top_mask = top_mask.long()

        # Shape: (batch_size, max_num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices,
                                               flat_top_indices)

        return top_embeddings, top_mask, top_indices, top_scores
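The interaction between a per-sentence num_items_to_keep and the top_indices_mask built from get_mask_from_sequence_lengths is easy to see on toy numbers (everything below is made up; only torch and allennlp.nn.util are assumed):

import torch
from allennlp.nn.util import get_mask_from_sequence_lengths

# Keep 2 items from the first sentence and 3 from the second (made-up values).
num_items_to_keep = torch.tensor([2, 3])
max_items_to_keep = int(num_items_to_keep.max())

scores = torch.randn(2, 6, 1)                        # (batch_size, num_items, 1)
_, top_indices = scores.topk(max_items_to_keep, 1)   # (batch_size, max_items_to_keep, 1)

# Sentences with a smaller k have the tail of their top-k masked out, exactly as above.
top_indices_mask = get_mask_from_sequence_lengths(num_items_to_keep, max_items_to_keep)
print(top_indices.squeeze(-1))
print(top_indices_mask)   # e.g. tensor([[1, 1, 0], [1, 1, 1]])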
Example No. 33
    def _get_entity_action_logits(self,
                                  state: WikiTablesDecoderState,
                                  actions_to_link: List[List[int]],
                                  attention_weights: torch.Tensor) -> Tuple[torch.FloatTensor,
                                                                            torch.LongTensor,
                                                                            torch.FloatTensor]:
        """
        Returns scores for each action in ``actions_to_link`` that are derived from the linking
        scores between the question and the table entities, and the current attention on the
        question.  The intuition is that if we're paying attention to a particular word in the
        question, we should tend to select entity productions that we think that word refers to.
        We additionally return a mask representing which elements in the returned ``action_logits``
        tensor are just padding, and an embedded representation of each action that can be used as
        input to the next step of the encoder.  That embedded representation is derived from the
        type of the entity produced by the action.

        The ``actions_to_link`` are in terms of the `batch` action list passed to
        ``model.forward()``.  We need to convert these integers into indices into the linking score
        tensor, which has shape (batch_size, num_entities, num_question_tokens), look up the
        linking score for each entity, then aggregate the scores using the current question
        attention.

        Parameters
        ----------
        state : ``WikiTablesDecoderState``
            The current state.  We'll use this to get the linking scores.
        actions_to_link : ``List[List[int]]``
            A list of _batch_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.  This is expected to be output from
            :func:`_get_actions_to_consider`.
        attention_weights : ``torch.Tensor``
            The current attention weights over the question tokens.  Should have shape
            ``(group_size, num_question_tokens)``.

        Returns
        -------
        action_logits : ``torch.FloatTensor``
            A score for each of the given actions.  Shape is ``(group_size, num_actions)``, where
            ``num_actions`` is the maximum number of considered actions for any group element.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        type_embeddings : ``torch.FloatTensor``
            A tensor of shape ``(group_size, num_actions, action_embedding_dim)``, with an embedded
            representation of the `type` of the entity corresponding to each action.
        """
        # First we map the actions to entity indices, using state.actions_to_entities, and find the
        # type of each entity using state.entity_types.
        action_entities: List[List[int]] = []
        entity_types: List[List[int]] = []
        for batch_index, action_list in zip(state.batch_indices, actions_to_link):
            action_entities.append([])
            entity_types.append([])
            for action_index in action_list:
                entity_index = state.actions_to_entities[(batch_index, action_index)]
                action_entities[-1].append(entity_index)
                entity_types[-1].append(state.entity_types[entity_index])

        # Then we create a padded tensor suitable for use with
        # `state.flattened_linking_scores.index_select()`.
        num_actions = [len(action_list) for action_list in action_entities]
        max_num_actions = max(num_actions)
        padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                          for action_list in action_entities]
        padded_types = [common_util.pad_sequence_to_length(type_list, max_num_actions)
                        for type_list in entity_types]
        # Shape: (group_size, num_actions)
        action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)
        type_tensor = state.score[0].new_tensor(padded_types, dtype=torch.long)

        # To get the type embedding tensor, we just use an embedding matrix on the list of entity
        # types.
        type_embeddings = self._entity_type_embedding(type_tensor)

        # `state.flattened_linking_scores` is shape (batch_size * num_entities, num_question_tokens).
        # We want to select from this using `action_tensor` to get a tensor of shape (group_size,
        # num_actions, num_question_tokens).  Unfortunately, the index_select functions in nn.util
        # don't do this operation.  So we'll do some reshapes and do the index_select ourselves.
        group_size = len(state.batch_indices)
        num_question_tokens = state.flattened_linking_scores.size(-1)
        flattened_actions = action_tensor.view(-1)
        # (group_size * num_actions, num_question_tokens)
        flattened_action_linking = state.flattened_linking_scores.index_select(0, flattened_actions)
        # (group_size, num_actions, num_question_tokens)
        action_linking = flattened_action_linking.view(group_size, max_num_actions, num_question_tokens)

        # Now we get action logits by weighting these entity x token scores by the attention over
        # the question tokens.  We can do this efficiently with torch.bmm.
        action_logits = action_linking.bmm(attention_weights.unsqueeze(-1)).squeeze(-1)

        # Finally, we make a mask for our action logit tensor.
        sequence_lengths = action_linking.new_tensor(num_actions)
        action_mask = util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
        return action_logits, action_mask, type_embeddings
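The first loop in this method only converts batch-level action ids into rows of the flattened linking-score matrix through state.actions_to_entities; with made-up dictionaries the mapping works like this:

# Made-up stand-ins for the relevant pieces of WikiTablesDecoderState.
batch_indices = [0, 1]
actions_to_link = [[5, 7], [2]]
actions_to_entities = {(0, 5): 3, (0, 7): 10, (1, 2): 14}   # (batch, action) -> flattened row
entity_types_by_index = {3: 0, 10: 1, 14: 0}                # entity row -> type id

action_entities, entity_types = [], []
for batch_index, action_list in zip(batch_indices, actions_to_link):
    action_entities.append([actions_to_entities[(batch_index, action)]
                            for action in action_list])
    entity_types.append([entity_types_by_index[entity] for entity in action_entities[-1]])

print(action_entities)   # [[3, 10], [14]]
print(entity_types)      # [[0, 1], [0]]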
Example No. 34
    def forward(
            self,  # type: ignore
            source_features: torch.FloatTensor,
            source_lengths: torch.LongTensor,
            target_tokens: Dict[str, torch.LongTensor] = None,
            words: Dict[str, torch.LongTensor] = None,
            segments: torch.LongTensor = None,
            pos_tags: torch.LongTensor = None,
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None,
            epoch_num: int = None,
            dataset: str = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_features : ``torch.FloatTensor``
           A tensor of acoustic input features of shape ``(batch_size, timesteps, feature_size)``,
           which is normalized, masked to the true source lengths, and passed through the encoder.
        source_lengths : ``torch.LongTensor``
           The number of valid frames in each element of ``source_features``, used to build the
           source mask.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        output_dict = {}
        if dataset is not None:
            self._target_granularity = dataset[0]

        if epoch_num is not None:
            self._epoch_num = epoch_num[0]
        self.set_output_layer_num()

        source_mask = util.get_mask_from_sequence_lengths(
            source_lengths, source_features.size(1)).bool()

        source_features = source_features.unsqueeze(1)  # make a channel dim
        if self._delta:
            source_features = self._delta(source_features)

        batch_size, n_channels, timesteps, feature_size = source_features.size()
        source_features = self._input_norm(
            source_features.transpose(-2, -1).reshape(batch_size, -1, timesteps)) \
            .view(batch_size, n_channels, feature_size, timesteps).transpose(-2, -1)
        source_features = self.time_mask(source_features, source_mask)
        source_features = self.freq_mask(source_features, source_mask)

        source_features = source_features.masked_fill(
            ~source_mask.unsqueeze(1).unsqueeze(-1).expand_as(source_features),
            0.0)
        state = self._encode(source_features, source_lengths)
        source_lengths = util.get_lengths_from_binary_sequence_mask(
            state["source_mask"])
        target_tokens["mask"] = (target_tokens[self._target_namespace] !=
                                 self._pad_index).bool()

        if self._phn_ctc_layer and \
            (self._phn_target_namespace in self._target_granularity or self._train_at_phn_level):
            raise NotImplementedError
            # logits = self._projection_layer(state["encoder_outputs"])
            # phn_ctc_output_dict = self._phn_ctc_layer(logits, source_lengths, target_tokens)
            # output_dict.update({f"phn_ctc_{key}": value for key, value in phn_ctc_output_dict.items()})

        if self._rnnt_layer is not None and self._rnnt_layer.loss_ratio > 0.0:
            rnnt_output_dict = self._rnnt_layer(state["encoder_outputs"],
                                                source_lengths, target_tokens)
            output_dict.update({
                f"rnnt_{key}": value
                for key, value in rnnt_output_dict.items()
            })
        if self._ctc_layer is not None and self._ctc_layer.loss_ratio > 0.0:
            logits = self._projection_layer(state["encoder_outputs"])
            ctc_output_dict = self._ctc_layer(logits, source_lengths,
                                              target_tokens)
            output_dict.update({
                f"ctc_{key}": value
                for key, value in ctc_output_dict.items()
            })

        if target_tokens and self._att_ratio > 0.0 and \
            self._target_granularity == self._target_namespace:
            targets = target_tokens[self._target_namespace]
            output_dict["target_tokens"] = targets
            target_mask = util.get_text_field_mask(target_tokens)
            if self._train_at_phn_level:
                raise NotImplementedError
                # state = self._get_phn_level_representations(
                #     state["encoder_outputs"].detach().requires_grad_(True),
                #     state["source_mask"],
                #     output_dict["phn_ctc"])

            state = self._init_decoder_state(state)
            output_dict.update(self._forward_loop(state, target_tokens))
            self._logs["att_wer"](output_dict["predictions"], targets)

            if self._dep_parser or self._pos_tagger:
                relevant_mask = target_mask[:, 1:]
                attention_contexts, _ = _remove_eos(
                    output_dict["attention_contexts"], relevant_mask)
                if segments is not None:
                    segments, _ = remove_sentence_boundaries(
                        segments, target_mask)
                    attention_contexts, _ = \
                        char_to_word(attention_contexts, segments)
                contexts = {"tokens": attention_contexts}
                if self._dep_parser:
                    parser_outputs = self._dep_parser(contexts, pos_tags,
                                                      metadata, head_tags,
                                                      head_indices)
                    parser_outputs["dep_loss"] = parser_outputs.pop("loss")
                    output_dict.update(parser_outputs)
                if self._pos_tagger:
                    tagger_outputs = self._pos_tagger(contexts, pos_tags,
                                                      metadata)
                    tagger_outputs["pos_loss"] = tagger_outputs.pop("loss")
                    output_dict.update(tagger_outputs)

        if not self.training:
            if self._target_granularity == self._target_namespace:
                if self._att_ratio > 0.0:
                    state = self._init_decoder_state(state)
                    predictions = self._forward_beam_search(state)
                    output_dict.update(predictions)
                    if target_tokens:
                        targets = target_tokens[self._target_namespace]
                        # shape: (batch_size, beam_size, max_sequence_length)
                        top_k_predictions = output_dict["predictions"]
                        # shape: (batch_size, max_predicted_sequence_length)
                        best_predictions = top_k_predictions[:, 0, :]
                        self._logs["att_bleu"](best_predictions, targets)
                        self._logs["att_wer"](best_predictions, targets)
                    log_dict = self.decode(output_dict)
                    verbose_target = [
                        self._indices_to_tokens(tokens.tolist()[1:])
                        for tokens in target_tokens[self._target_namespace]
                    ]
                    verbose_best_pred = [
                        beams[0] for beams in log_dict["predicted_tokens"]
                    ]
                    sep = " " if self._target_namespace == 'tokens' else ""
                    with open(f"preds.{epoch_num[0]}.txt", "a+") as fp:
                        fp.write("\n".join([
                            sep.join(
                                map(lambda s: re.sub(self._blank, " ", s),
                                    words)) for words in verbose_best_pred
                        ]))
                        fp.write("\n")
                    with open(f"golds.{epoch_num[0]}.txt", "a+") as fp:
                        fp.write("\n".join([
                            sep.join(
                                map(lambda s: re.sub(self._blank, " ", s),
                                    words)) for words in verbose_target
                        ]))
                        fp.write("\n")
                    # for gold, pred in zip(verbose_target, verbose_best_pred):
                    #     print(gold, pred)

        if self.training:
            output_dict = self._collect_losses(
                output_dict,
                ctc=(self._ctc_layer.loss_ratio if self._ctc_layer else 0),
                rnnt=(self._rnnt_layer.loss_ratio if self._rnnt_layer else 0),
                att=self._att_ratio,
                dal=self._latency_penalty,
                dep=self._dep_ratio,
                pos=self._pos_ratio)
            if torch.isnan(output_dict["loss"]).any() or \
                    (torch.abs(output_dict["loss"]) == float('inf')).any():
                for key, _ in output_dict.items():
                    if "loss" in key:
                        output_dict[key] = output_dict[key].new_zeros(
                            size=(), requires_grad=True).clone()
        self._update_metrics(output_dict)

        return output_dict
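The length masking at the top of this forward (a boolean mask from source_lengths followed by masked_fill on the padded frames) can be reproduced in isolation; the shapes below are made up and only torch plus allennlp.nn.util are assumed:

import torch
from allennlp.nn import util

batch_size, timesteps, feature_size = 2, 7, 3              # made-up sizes
source_features = torch.randn(batch_size, timesteps, feature_size)
source_lengths = torch.tensor([7, 4])

source_mask = util.get_mask_from_sequence_lengths(source_lengths, timesteps).bool()

# Add a channel dimension, then zero every frame beyond each utterance's true length.
source_features = source_features.unsqueeze(1)             # (batch, 1, timesteps, feature_size)
source_features = source_features.masked_fill(
    ~source_mask.unsqueeze(1).unsqueeze(-1).expand_as(source_features), 0.0)
print(source_features[1, 0, 4:].abs().sum())               # tensor(0.) -- padded frames zeroed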