Example 1
    def forward(self, v, b, q, e, labels):
        """Forward

        v: [batch, num_objs, obj_dim]
        b: [batch, num_objs, b_dim]
        q: [batch, seq_length]
        e: [batch, num_entities]

        return: logits, not probs
        """
        assert q.size(1) > e.data.max(), 'len(q)=%d > e_pos.max()=%d' % (
            q.size(1), e.data.max())
        MINUS_INFINITE = -99  # large negative value standing in for -inf in masked attention logits
        if 's' in self.op:
            v = torch.cat([v, b], 2)
        w_emb = self.w_emb(q)
        q_emb = self.q_emb.forward_all(w_emb)  # [batch, q_len, q_dim]
        # keep only the question hidden states at the entity positions
        q_emb = utils.batched_index_select(q_emb, 1, e)

        att = self.v_att.forward_all(v, q_emb, True, True,
                                     MINUS_INFINITE)  # b x g x v x q
        # padded entity slots point at position 0
        mask = (e == 0).unsqueeze(1).unsqueeze(2).expand(att.size())
        mask[:, :, :, 0].data.fill_(0)  # at least one entity per sentence
        att.data.masked_fill_(mask.data, MINUS_INFINITE)

        return None, att
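
Both examples rely on a batched_index_select helper whose definition is not shown. A minimal sketch of a gather-based implementation matching the three-argument call above (utils.batched_index_select(q_emb, 1, e)) might look like the following; the exact helper in the source may differ:

import torch

def batched_index_select(input, dim, index):
    # Sketch only: select entries along `dim` independently for each batch element.
    # input: [batch, ..., n, ...], index: [batch, k] -> output has size k along `dim`.
    views = [input.shape[0]] + [1 if i != dim else -1 for i in range(1, input.dim())]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)

# e.g. x: [2, 5, 3], idx: [2, 2] -> [2, 2, 3]
x = torch.randn(2, 5, 3)
idx = torch.tensor([[0, 2], [4, 1]])
assert batched_index_select(x, 1, idx).shape == (2, 2, 3)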
Example 2
    def forward(self, batch):
        embeddings = self.model(
            batch['input_ids'],
            attention_mask=batch['attention_masks'],
            token_type_ids=batch['segment_ids'],
        )[0]  # [B, L_doc, 768]

        cls_embeddings = batched_index_select(embeddings, batch['index_ids'])
        # zero out positions whose label is -1 (padding)
        cls_embeddings *= (batch['label_ids'] != -1).float().unsqueeze(2)

        logits = self.classifier(
            self.dropout(cls_embeddings))  # [B, L_cls, 2]
        logits = logits.transpose(1, 2)  # [B, 2, L_cls]

        return logits
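
Note that this second example calls batched_index_select with only two arguments, so its local helper presumably fixes the selection dimension to 1. For reference, a rough shape walk-through of this forward pass using the three-argument sketch above; all sizes (B=4, L_doc=128, L_cls=10, hidden size 768, 2 classes) are assumptions for illustration:

import torch

B, L_doc, L_cls, H = 4, 128, 10, 768                    # assumed sizes
embeddings = torch.randn(B, L_doc, H)                   # encoder output [B, L_doc, 768]
index_ids = torch.randint(0, L_doc, (B, L_cls))         # positions of the [CLS]-like tokens
cls_embeddings = batched_index_select(embeddings, 1, index_ids)           # [B, L_cls, 768]
label_ids = torch.randint(-1, 2, (B, L_cls))
cls_embeddings = cls_embeddings * (label_ids != -1).float().unsqueeze(2)  # zero padded slots
logits = torch.nn.Linear(H, 2)(cls_embeddings)          # [B, L_cls, 2]
logits = logits.transpose(1, 2)                         # [B, 2, L_cls], as nn.CrossEntropyLoss expects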
    def actor_layer(self, batch_state, mask, n_label_sents=None):
        """
        Determines which sentences to extract for each of the documents represented by batch_state

        :param batch_state:     A torch.tensor representing sentence embeddings of each document within the batch.
                                Shape: (batch_size, n_doc_sentences, embedding_dim)
        :param mask:            A torch.tensor of booleans indicating whether or not the document within the batch
                                actually has the sentence. This is necessary because we've batched multiple documents
                                together of various lengths.
        :param n_label_sents:   An optional list containing number of extracted sentences in summary labels (oracle)
        :return:                A tuple containing:
                                 - action_dists: list(Categorical()) containing categorical distributions. Each entry
                                                 represents the distribution amongst sentences to extract at a given
                                                 step for all batches.
                                 - action_indices: torch.tensor() containing the indicies of extracted sentences
                                                    Shape: (batch_size, n_extracted_sentences, embedding_dim)
                                 - ext_sents: A torch.tensor() containing extracted sentence embeddings.
                                              Shape: (batch_size, n_extracted_sentences, embedding_dim)
                                 - n_ext_sents: A torch.tensor() where entries show # of sentences extracted per sample
        """
        # Append the stop action to each document and extend the mask accordingly
        batch_state, mask = self.add_stop_action(batch_state, mask)

        # Obtain number of samples in batch
        batch_size = batch_state.shape[0]
        max_n_doc_sents = batch_state.shape[1]
        embedding_dim = batch_state.shape[2]

        # Obtain maximum number of sentences to extract
        if n_label_sents is None:
            n_doc_sents = mask.sum(dim=1)
            batch_max_n_ext_sents = torch.tensor([self.max_n_ext_sents] *
                                                 batch_size)
            batch_max_n_ext_sents = torch.min(batch_max_n_ext_sents.float(),
                                              n_doc_sents)
        else:
            batch_max_n_ext_sents = n_label_sents

        # Create variables to stop extraction loop
        max_n_ext_sents = batch_max_n_ext_sents.max()  # maximum number of sentences to extract
        stop_action_idx = max_n_doc_sents - 1  # index of the previously appended stop embedding
        is_stop_action = torch.zeros(batch_size).bool()

        src_doc_lengths = torch.sum(mask, dim=1)
        padded_state = batch_state  # keep the padded tensor for indexing extracted sentences below
        batch_state = torch.nn.utils.rnn.pack_padded_sequence(
            batch_state,
            lengths=src_doc_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False)

        # Extraction loop
        action_indices, ext_sents, action_dists, stop_action_list = [], [], [], []
        n_ext_sents = 0
        is_first_sent = True
        extraction_labels = None
        while True:
            # Obtain distribution amongst sentences to extract
            if is_first_sent:
                action_probs, __ = self.extraction_model.forward(
                    batch_state, mask)
                is_first_sent = False
            else:
                action_probs, __ = self.extraction_model.forward(
                    sent_embeddings=batch_state,
                    sent_mask=mask,
                    extraction_indicator=extraction_labels,
                    use_init_embedding=True)
            action_probs = action_probs[:, -1:, :]  # keep only the latest step's distribution
            action_dist = Categorical(action_probs)

            # Sample sentence to extract
            ext_sent_indices = action_dist.sample().T

            # Embeddings of sentences to extract (index the padded tensor, not the PackedSequence)
            ext_sent_embeddings = batched_index_select(padded_state, 1,
                                                       ext_sent_indices)

            # Collect
            action_dists.append(action_dist)
            ext_sents.append(ext_sent_embeddings)
            action_indices.append(ext_sent_indices)

            # Form extraction_labels
            extraction_labels = torch.zeros(batch_size, max_n_doc_sents)
            already_ext_indices = torch.cat(action_indices)
            extraction_labels[torch.arange(batch_size),
                              already_ext_indices] = 1

            # Track number of sentences extracted from article
            n_ext_sents = n_ext_sents + 1

            # Check whether extraction should stop for each sample
            # TODO: Fix this, the mask ALWAYS masks out the stop action...
            is_stop_action = is_stop_action | (ext_sent_indices >= stop_action_idx)
            stop_action_list.append(is_stop_action)
            all_samples_stop = torch.sum(is_stop_action) >= batch_size
            is_long_enough = n_ext_sents >= max_n_ext_sents
            if all_samples_stop or is_long_enough:
                break

        action_indices = torch.stack(action_indices).T.squeeze(1)
        n_ext_sents = (~torch.stack(stop_action_list).squeeze(1).T).sum(dim=1)
        ext_sents = torch.stack(ext_sents).transpose(0, 1).squeeze()
        return action_dists, action_indices, ext_sents, n_ext_sents
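
self.add_stop_action is not shown in this example. A plausible minimal implementation appends one extra "stop" slot after each document's sentences and marks it valid in the mask; the TODO in the loop above suggests the real implementation may handle the mask differently. The stop_embedding argument here stands in for what is likely a learned parameter on the module:

import torch

def add_stop_action(batch_state, mask, stop_embedding):
    # Sketch only: append a "stop" slot after each document's sentences.
    # batch_state: [batch, n_sents, dim], mask: [batch, n_sents] (bool)
    # stop_embedding: [dim]; in the real model this is likely a learned parameter
    batch_size, _, dim = batch_state.shape
    stop = stop_embedding.view(1, 1, dim).expand(batch_size, 1, dim)
    batch_state = torch.cat([batch_state, stop], dim=1)
    mask = torch.cat([mask, torch.ones(batch_size, 1, dtype=mask.dtype)], dim=1)
    return batch_state, mask

# e.g. [2, 4, 8] -> [2, 5, 8]; the appended slot is the last index, matching
# stop_action_idx = max_n_doc_sents - 1 in the extraction loop above
state, m = add_stop_action(torch.randn(2, 4, 8), torch.ones(2, 4, dtype=torch.bool), torch.zeros(8))
assert state.shape == (2, 5, 8) and m.shape == (2, 5)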