Exemplo n.º 1
0
    def selection_step(self,
                       cur_sum,
                       cur_len,
                       sentence_sums,
                       sentence_lens,
                       sentence_mask,
                       sentence_label=None):
        """Choose one more sentence for the running selection.

        Every candidate sentence is scored by ``self.sentence_classifier``
        applied to the mean embedding the selection would have if that
        sentence were added: ``(cur_sum + sentence_sum) / (cur_len +
        sentence_len)``. Unavailable candidates are suppressed with
        ``utils.mask_tensor`` before a choice is made.

        The choice is a one-hot vector over candidates:
          * training with ``config.teacher_forcing`` and a label: the label;
          * training otherwise: a hard Gumbel-softmax sample of the logits;
          * evaluation: the argmax of the logits.

        Args:
            cur_sum: Summed embedding of the sentences chosen so far
                (presumably ``(batch, hidden)`` -- TODO confirm).
            cur_len: Length (divisor) matching ``cur_sum``.
            sentence_sums: Per-candidate summed embeddings; dim 1 indexes
                candidate sentences.
            sentence_lens: Per-candidate lengths, broadcastable against
                ``sentence_sums``.
            sentence_mask: Availability mask over candidates; the chosen
                candidate is zeroed in the returned mask.
            sentence_label: Optional gold choice used for teacher forcing.

        Returns:
            Tuple ``(sentence_logits, new_embedding, new_len, sentence_mask,
            one_hot)``. ``new_embedding`` and ``new_len`` are the summed
            embedding and length of the chosen candidate (previous selection
            included); ``one_hot`` is the selection vector.
        """
        extended_sums = cur_sum.unsqueeze(1) + sentence_sums
        extended_lens = cur_len.unsqueeze(1) + sentence_lens
        candidate_count = extended_sums.size(1)

        mean_embeddings = extended_sums / extended_lens
        logits = self.sentence_classifier(mean_embeddings).squeeze(-1)
        logits = utils.mask_tensor(logits, sentence_mask.detach())

        if not self.training:
            best = torch.argmax(logits, -1)
            choice = utils.convert_single_one_hot(best, candidate_count)
        elif self.config.teacher_forcing and sentence_label is not None:
            choice = utils.convert_single_one_hot(sentence_label,
                                                  candidate_count)
        else:
            # Hard sample with a straight-through gradient.
            choice = F.gumbel_softmax(logits, hard=True)

        remaining = (1 - choice) * sentence_mask

        # Select the chosen candidate's sum/length by one-hot weighting.
        weights = choice.unsqueeze(-1)
        new_embedding = (weights * extended_sums).sum(dim=1)
        new_len = (weights * extended_lens).sum(dim=1)

        return logits, new_embedding, new_len, remaining, weights.squeeze(-1)
Exemplo n.º 2
0
    def forward(self, c, q, cw, mask):
        """Attend over ``cw`` with a query built from ``c`` and ``q``.

        ``c`` is truncated to ``q``'s batch size, concatenated with ``q``
        along dim 1 and projected by ``self.linear_cq``. That projection is
        multiplied elementwise with every position of ``cw`` and scored by
        ``self.linear_ca``. Scores are masked through ``mask_tensor`` with a
        ``-inf`` fill value and softmax-normalised over dim 1.

        NOTE(review): if ``mask`` hides every position of a row, the softmax
        of an all ``-inf`` row yields NaN.

        Returns:
            ``(c, cv)``: the attention-weighted sum of ``cw`` over dim 1 and
            the attention weights themselves.
        """
        query = self.linear_cq(torch.cat([c[:q.size(0)], q], dim=1))

        # Elementwise interaction between the query and each position of cw.
        logits = self.linear_ca(query.unsqueeze(1) * cw)
        logits = mask_tensor(logits, mask, float('-inf'))

        weights = F.softmax(logits, 1)
        return (weights * cw).sum(1), weights
Exemplo n.º 3
0
    def _calc_mention_logits(self, start_mention_reps, end_mention_reps):
        """Score every (start, end) token pair as a candidate mention.

        The score of pair (i, j) is the sum of three terms:
          * a pairwise term: ``mention_s2e_classifier(start)[i]`` dotted
            with ``end[j]``;
          * a start-only score for token i;
          * an end-only score for token j.
        Invalid pairs are then masked using ``self._get_mention_mask``.

        Returns:
            Masked logits of shape ``[batch_size, seq_length, seq_length]``.
        """
        start_scores = self.mention_start_classifier(start_mention_reps).squeeze(-1)  # [batch_size, seq_length]
        end_scores = self.mention_end_classifier(end_mention_reps).squeeze(-1)  # [batch_size, seq_length]

        # [batch_size, seq_length, hidden]; hidden must equal end_mention_reps' feature size.
        projected_starts = self.mention_s2e_classifier(start_mention_reps)
        pair_scores = projected_starts @ end_mention_reps.transpose(1, 2)  # [batch_size, seq_length, seq_length]

        # Broadcast: row i receives start score i, column j receives end score j.
        logits = pair_scores + start_scores.unsqueeze(-1) + end_scores.unsqueeze(-2)
        return mask_tensor(logits, self._get_mention_mask(logits))  # [batch_size, seq_length, seq_length]
def _create_sentence_embeddings(model, ids, model_input, sentence_indicators,
                                num_selected=3, output_path='sim_oracle5.p'):
    """Build a greedy cosine-similarity sentence oracle for each document.

    For each document the encoder's token states are summed per sentence,
    using the per-token sentence ids in ``sentence_indicators``. Sentences
    are then picked greedily. At each step, the pick is the sentence whose
    addition makes the mean embedding of the selection most cosine-similar
    to the attention-masked sum of all token states.

    Args:
        model: Encoder called as ``model(input_ids=..., attention_mask=...)``;
            element 0 of its output is used as the token hidden states
            (presumably ``(1, seq_len, hidden)`` -- TODO confirm).
        ids: Document ids, used as keys of the result; indexed in parallel
            with ``model_input`` and ``sentence_indicators``.
        model_input: Mapping with ``'input_ids'`` and ``'attention_mask'``
            per-document lists.
        sentence_indicators: Per-document, per-token sentence index.
        num_selected: Maximum number of sentences picked per document
            (default 3, the previously hard-coded value).
        output_path: File the resulting dict is pickled to.

    Returns:
        dict mapping each id to the list of selected sentence indices, in
        selection order. The list can be shorter than ``num_selected`` when
        the document has fewer selectable sentences.
    """
    oracle = {}
    sim = torch.nn.CosineSimilarity(-1)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for idx in tqdm(range(len(ids))):
            inputs = {
                'input_ids': torch.tensor([model_input['input_ids'][idx]]).cuda(),
                'attention_mask': torch.tensor([model_input['attention_mask'][idx]]).cuda(),
            }
            sentence_indicator = torch.tensor([sentence_indicators[idx]]).cuda()
            hidden_states = model(**inputs)[0]

            # Sum the token states of each sentence and count its tokens.
            sentences = []
            sentence_lens = []
            for i in range(int(sentence_indicator.max()) + 1):
                mask = (sentence_indicator == i).long()
                sentences.append(torch.sum(hidden_states * mask.unsqueeze(-1), dim=1))
                sentence_lens.append(mask.sum(dim=1).view(-1, 1))

            sentences = torch.stack(sentences, dim=1)
            # Clamp so a sentence id with no tokens cannot cause division by zero.
            sentence_lens = torch.stack(sentence_lens, dim=1).clamp(min=1)
            pooled_embedding = (hidden_states * inputs['attention_mask'].unsqueeze(-1)).sum(1).unsqueeze(1)

            sentence_mask = utils.get_sentence_mask(sentence_indicator, sentences.size(1)).float()

            cur = torch.zeros(sentences.size(0), sentences.size(-1)).cuda()
            cur_len = torch.zeros(sentence_lens.size(0), sentence_lens.size(-1)).cuda()
            selected = []
            for _ in range(num_selected):
                # Picked sentences are zeroed in the mask; once nothing is
                # left, further argmax calls would only repeat masked indices.
                if not sentence_mask.any():
                    break
                candidates = cur.unsqueeze(1) + sentences
                candidate_lens = cur_len.unsqueeze(1) + sentence_lens
                scores = sim(candidates / candidate_lens, pooled_embedding)
                scores = utils.mask_tensor(scores, sentence_mask)

                index = int(torch.argmax(scores))
                cur = candidates[:, index]
                # The running length comes from candidate_lens. The original
                # took it from candidates (summed embeddings), which corrupted
                # every step after the first.
                cur_len = candidate_lens[:, index]
                sentence_mask[:, index] = 0
                selected.append(index)

            oracle[ids[idx]] = selected

    with open(output_path, 'wb') as f:
        pickle.dump(oracle, f)
    return oracle
Exemplo n.º 5
0
    def forward(self, x, q, lengths, mask):
        """Encode the image input and the question.

        Args:
            x: Input passed to ``self.image_extractor``.
            q: Question tokens passed to ``self.embed``.
            lengths: 1-D tensor (indexed as ``lengths[:, None, None]``) that
                gives, per example, the time step whose first ``self.d``
                LSTM features are used as part of the question summary.
                NOTE(review): if these are raw sequence lengths rather than
                last-token indices, the gather is off by one -- confirm
                against the caller.
            mask: Mask over the projected word states, passed to
                ``mask_tensor`` with fill value ``0.``.

        Returns:
            ``(k, q_, q, cw)``: image features, the concatenated question
            summary, ``self.p`` stacked projections of that summary, and the
            masked projected word states.
        """
        k = self.image_extractor(x)

        embedding = self.embed(q)
        cw = self.lstm(embedding)[0]

        # Split the LSTM output at feature index self.d: features past d are
        # taken at step 0, features before d at the step given by `lengths`
        # (presumably the backward/forward halves of a bidirectional LSTM).
        indices = lengths[:, None, None].repeat(1, 1, self.d)
        cw_1 = cw[:, 0, self.d:]
        # squeeze(1), not squeeze(): a bare squeeze() also drops the batch
        # dimension when the batch size is 1 (or d == 1), which breaks the
        # concatenation below.
        cw_s = cw[:, :, :self.d].gather(1, indices).squeeze(1)
        q_ = torch.cat([cw_1, cw_s], -1)
        q = torch.stack([self.linear_q[i](q_) for i in range(self.p)])

        cw = self.linear_cw(cw)
        cw = mask_tensor(cw, mask, 0.)

        return k, q_, q, cw
Exemplo n.º 6
0
    def _get_marginal_log_likelihood_loss(self, coref_logits, cluster_labels_after_pruning, span_mask):
        """Negative marginal log-likelihood of the gold antecedents.

        Per span: ``-(logsumexp(gold logits) - logsumexp(all logits))``,
        where the gold logits are ``coref_logits`` masked with
        ``cluster_labels_after_pruning``. Losses of spans outside
        ``span_mask`` are zeroed, summed per example, optionally divided by
        the padded span count ``max_k`` (when ``self.normalise_loss``), and
        averaged over the batch.

        :param coref_logits: [batch_size, max_k, max_k]
        :param cluster_labels_after_pruning: [batch_size, max_k, max_k];
            used as the mask selecting gold antecedents (presumably 1 = gold)
        :param span_mask: [batch_size, max_k]
        :return: scalar loss tensor
        """
        gold_only = mask_tensor(coref_logits, cluster_labels_after_pruning)

        log_gold_mass = torch.logsumexp(gold_only, dim=-1)  # [batch_size, max_k]
        log_total_mass = torch.logsumexp(coref_logits, dim=-1)  # [batch_size, max_k]

        span_losses = -(log_gold_mass - log_total_mass) * span_mask
        per_example_loss = span_losses.sum(dim=-1)  # [batch_size]
        if self.normalise_loss:
            # Divides by the padded span count (max_k), not by the number of real spans.
            per_example_loss = per_example_loss / span_losses.size(-1)
        return per_example_loss.mean()
Exemplo n.º 7
0
 def _mask_antecedent_logits(self, antecedent_logits, span_mask):
     """Restrict antecedent logits to valid (span i, earlier span j) pairs.

     Pair (i, j) stays unmasked only when j < i (strictly lower triangle)
     and span i is valid according to ``span_mask``. All other entries are
     masked with ``mask_tensor``.
     """
     # Entry (i, j) is 1 only when j comes strictly before i.
     allowed = torch.ones_like(antecedent_logits, dtype=self.dtype).tril(diagonal=-1)  # [batch_size, k, k]
     # Broadcast span_mask over columns so rows of invalid spans are zeroed.
     allowed = allowed * span_mask.unsqueeze(-1)  # [batch_size, k, k]
     return mask_tensor(antecedent_logits, allowed)