def selection_loop(self,
                       hidden_states,
                       sentence_indicator,
                       sentence_labels,
                       pmi_features=None,
                       two_selections=False):
        """Greedily select up to ``config.extraction_k`` sentences.

        Builds one summed embedding per sentence from token-level
        ``hidden_states`` (grouped via ``sentence_indicator``), then calls
        ``self.selection_step`` repeatedly to pick one sentence per step,
        accumulating a one-hot selection mask.

        Args:
            hidden_states: token embeddings; assumed (batch, seq_len, dim)
                from the ``dim=1`` reduction — TODO confirm.
            sentence_indicator: (batch, seq_len) integer tensor mapping each
                token to its sentence index.
            sentence_labels: optional (batch, >=extraction_k) gold sentence
                indices; column ``i`` is forwarded to ``selection_step`` at
                step ``i`` (presumably teacher forcing — verify in
                ``selection_step``).
            pmi_features: optional per-sentence features concatenated onto the
                sentence embeddings when ``config.use_pmi`` is set.
            two_selections: when True, additionally return the mask restricted
                to the first 3 selections.

        Returns:
            ``(selected_one_hot, all_sentence_logits)`` or, when
            ``two_selections`` is True,
            ``(selected_one_hot, selected_one_hot1, all_sentence_logits)``.
        """
        all_sentence_logits = []
        sentences = []
        sentence_lens = []
        # Sum token embeddings per sentence; sentence ids run 0..max inclusive.
        for i in range(sentence_indicator.max() + 1):
            mask = (sentence_indicator == i).long().cuda()
            sentence_embedding = torch.sum(hidden_states * mask.unsqueeze(-1),
                                           dim=1)
            sentence_len = mask.sum(dim=1).view(-1, 1)
            sentences.append(sentence_embedding)
            sentence_lens.append(sentence_len)

        sentences = torch.stack(sentences, dim=1)
        sentence_lens = torch.stack(sentence_lens, dim=1)
        # Absent/padded sentences have length 0; clamp so later divisions by
        # length cannot hit zero.
        sentence_lens = sentence_lens.clamp(min=1)

        if self.config.use_pmi and pmi_features is not None:
            # Trim the feature table to the number of sentences actually
            # present, then append as an extra embedding dimension.
            pmi_features = pmi_features[:, :sentence_indicator.max() +
                                        1].unsqueeze(-1)

            sentences = torch.cat((sentences, pmi_features), dim=-1)

        cur_embedding = torch.zeros(sentences.size(0),
                                    sentences.size(-1)).cuda()
        cur_len = torch.zeros(sentence_lens.size(0),
                              sentence_lens.size(-1)).cuda()

        selected_one_hot = torch.zeros(sentences.size(0),
                                       sentences.size(1)).cuda()
        selected_one_hot1 = torch.zeros(sentences.size(0),
                                        sentences.size(1)).cuda()
        sentence_mask = utils.get_sentence_mask(sentence_indicator,
                                                sentences.size(1)).float()

        for i in range(self.config.extraction_k):
            sentence_logits, cur_embedding, cur_len, sentence_mask, one_hot = self.selection_step(
                cur_embedding, cur_len, sentences, sentence_lens,
                sentence_mask,
                sentence_labels[:, i] if sentence_labels is not None else None)
            selected_one_hot = selected_one_hot + one_hot
            if i < 3:  # secondary mask tracks only the first 3 picks
                selected_one_hot1 = selected_one_hot1 + one_hot
            all_sentence_logits.append(sentence_logits)
        selected_one_hot = selected_one_hot.clamp(max=1)
        # BUG FIX: was `selectd_one_hot1 = ...` (typo) — the clamp result was
        # discarded, so the secondary mask returned below could contain
        # values > 1 when a sentence was selected repeatedly.
        selected_one_hot1 = selected_one_hot1.clamp(max=1)
        if two_selections:
            return selected_one_hot, selected_one_hot1, all_sentence_logits
        return selected_one_hot, all_sentence_logits
def _create_sentence_embeddings(model, ids, model_input, sentence_indicators,
                                output_path='sim_oracle5.p',
                                num_selections=3):
    """Build a greedy similarity oracle of sentence selections per example.

    For each example, encodes the input with ``model``, forms per-sentence
    summed embeddings, and greedily picks ``num_selections`` sentences whose
    running mean embedding maximizes cosine similarity with the pooled
    (attention-masked sum) document embedding. The resulting
    ``id -> [sentence indices]`` dict is pickled to ``output_path``.

    Args:
        model: encoder whose output[0] is the token hidden-state tensor.
        ids: example identifiers, parallel to the rows of ``model_input``.
        model_input: dict with 'input_ids' and 'attention_mask' sequences.
        sentence_indicators: per-example token -> sentence-index maps.
        output_path: pickle file to write the oracle dict to.
        num_selections: how many sentences to select per example.

    Returns:
        The ``id -> [sentence indices]`` dict that was pickled.
    """
    d = {}
    sim = torch.nn.CosineSimilarity(-1)
    for idx in tqdm(range(len(ids))):
        inputs = {'input_ids': torch.tensor([model_input['input_ids'][idx]]).cuda(),
                  'attention_mask': torch.tensor([model_input['attention_mask'][idx]]).cuda()}
        sentence_indicator = torch.tensor([sentence_indicators[idx]]).cuda()
        output = model(**inputs)
        hidden_states = output[0]
        sentences = []
        sentence_lens = []

        # Sum token embeddings per sentence; sentence ids run 0..max inclusive.
        for i in range(sentence_indicator.max() + 1):
            mask = (sentence_indicator == i).long().cuda()
            sentence_embedding = torch.sum(hidden_states * mask.unsqueeze(-1), dim=1)
            sentence_len = mask.sum(dim=1).view(-1, 1)
            sentences.append(sentence_embedding)
            sentence_lens.append(sentence_len)

        sentences = torch.stack(sentences, dim=1)
        sentence_lens = torch.stack(sentence_lens, dim=1)
        # Guard against division by zero for absent/padded sentences.
        sentence_lens = sentence_lens.clamp(min=1)
        # Document embedding: masked sum of token states (not a mean).
        pooled_embedding = (hidden_states*inputs['attention_mask'].unsqueeze(-1)).sum(1).unsqueeze(1)

        sentence_mask = utils.get_sentence_mask(sentence_indicator, sentences.size(1)).float()

        cur = torch.zeros(sentences.size(0), sentences.size(-1)).cuda()
        cur_len = torch.zeros(sentence_lens.size(0), sentence_lens.size(-1)).cuda()
        selected = []
        for _ in range(num_selections):
            # Candidate summary = current sum + each sentence's sum.
            candidates = cur.unsqueeze(1) + sentences
            candidate_lens = cur_len.unsqueeze(1) + sentence_lens
            cur_embedding = candidates / candidate_lens

            scores = sim(cur_embedding, pooled_embedding)
            scores = utils.mask_tensor(scores, sentence_mask)
            index = torch.argmax(scores)
            cur = candidates[range(1), index]
            # BUG FIX: previously read ``candidates[range(1), index]`` here,
            # so the running "length" accumulated embedding sums and the mean
            # (candidates / candidate_lens) was corrupted from the second
            # iteration onward.
            cur_len = candidate_lens[range(1), index]
            # Zero out the chosen sentence so it cannot be picked again.
            sentence_mask[range(1), index] = 0
            selected.append(index.item())

        d[ids[idx]] = selected

    with open(output_path, 'wb') as f:
        pickle.dump(d, f)
    return d