Exemplo n.º 1
0
    def compute_representations(
        self,  # type: ignore
        span_embeddings,  # (1, Ns, E)
        coref_labels: torch.IntTensor,  # (1, Ns, C)
        type_to_cluster_ids: Dict[str, List[int]],
        relation_to_cluster_ids: Dict[int, List[int]] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Build cluster embeddings and score candidate relations over them.

        Each cluster is represented by the mean of the span embeddings that
        ``coref_labels`` assigns to it; clusters without any span fall back
        to the embedding returned by ``self.map_cluster_to_type_embeddings``.
        ``self._bias_vectors`` are appended as extra pseudo-clusters,
        candidate relations come from ``self.generate_product`` and are
        scored with ``self.get_relation_scores``.

        Parameters
        ----------
        span_embeddings :
            Span representations, annotated as ``(1, Ns, E)``.
        coref_labels : torch.IntTensor
            Span-to-cluster membership indicators, annotated as
            ``(1, Ns, C)``.
        type_to_cluster_ids : Dict[str, List[int]]
            Forwarded to ``map_cluster_to_type_embeddings`` and
            ``generate_product``.
        relation_to_cluster_ids : Dict[int, List[int]], optional
            Forwarded to ``generate_product``. When it is not ``None`` the
            output is additionally passed through ``self.predict_labels``.
        metadata : List[Dict[str, Any]], optional
            Returned unchanged in the output. ``metadata[0]["doc_id"]`` is
            copied to ``doc_id`` when metadata is provided.

        Returns
        -------
        ``{"loss": 0.0, "metadata": metadata}`` when there is no cluster
        membership at all or no candidate relation. Otherwise a dictionary
        with ``relations_candidates_list``, ``relation_labels``,
        ``relation_types``, ``doc_id``, ``metadata``, ``relation_scores``
        and ``relation_logits`` (possibly replaced by the result of
        ``self.predict_labels``).
        """
        # pylint: disable=arguments-differ

        no_result = {"loss": 0.0, "metadata": metadata}

        # No span belongs to any cluster: nothing to represent.
        if coref_labels.sum() == 0:
            return no_result

        cluster_type_embeddings = self.map_cluster_to_type_embeddings(
            type_to_cluster_ids)  # (1, C, E)

        # Float view of the membership indicators, shared by the numerator
        # and the denominator so that neither depends on implicit int/float
        # type promotion.
        membership = coref_labels.float().unsqueeze(-1)  # (1, Ns, C, 1)

        # Mean span embedding per cluster. The epsilon keeps the division
        # finite for clusters that contain no span.
        sum_embeddings = (span_embeddings.unsqueeze(2) * membership).sum(1)
        length_embeddings = membership.sum(1) + 1e-5
        cluster_span_embeddings = sum_embeddings / length_embeddings

        # 1.0 for clusters with at least one span, 0.0 otherwise.
        paragraph_cluster_mask = (coref_labels.sum(1) > 0).float().unsqueeze(
            -1)  # (P, C, 1)

        # Clusters without spans use their type embedding instead.
        paragraph_cluster_embeddings = (
            cluster_span_embeddings * paragraph_cluster_mask +
            cluster_type_embeddings * (1 - paragraph_cluster_mask)
        )  # (P, C, E)

        assert (paragraph_cluster_embeddings.shape[1] == coref_labels.shape[2]
                and paragraph_cluster_embeddings.shape[2]
                == span_embeddings.shape[-1])

        # Append the bias vectors as extra rows after the true clusters
        # (the original annotation gives the result as (P, C+4, E)).
        paragraph_cluster_embeddings = torch.cat(
            [
                paragraph_cluster_embeddings,
                self._bias_vectors.expand(
                    paragraph_cluster_embeddings.shape[0], -1, -1)
            ],
            dim=1,
        )
        n_true_clusters = coref_labels.shape[-1]

        (candidate_relations, candidate_relations_labels,
         candidate_relations_types) = self.generate_product(
             type_to_clusters_map=type_to_cluster_ids,
             relation_to_clusters_map=relation_to_cluster_ids,
             n_true_clusters=n_true_clusters,
         )

        # Bail out before allocating index tensors that would be discarded.
        if len(candidate_relations) == 0:
            return no_result

        device = span_embeddings.device
        candidate_relations_tensor = torch.LongTensor(candidate_relations).to(
            device)  # (R, 4)
        candidate_relations_labels_tensor = torch.LongTensor(
            candidate_relations_labels).to(device)  # (R, )

        # Look up the cluster embeddings participating in each candidate.
        all_relation_embeddings = util.batched_index_select(
            paragraph_cluster_embeddings,
            candidate_relations_tensor.unsqueeze(0).expand(
                paragraph_cluster_embeddings.shape[0], -1, -1),
        )  # (P, R', n, E)

        relation_scores, relation_logits = self.get_relation_scores(
            all_relation_embeddings)  # (1, R')

        output_dict = {
            "relations_candidates_list": candidate_relations,
            "relation_labels": candidate_relations_labels,
            "relation_types": candidate_relations_types,
            # ``metadata`` defaults to None, so guard the lookup instead of
            # crashing on ``None[0]``.
            "doc_id": metadata[0]["doc_id"] if metadata else None,
            "metadata": metadata,
            "relation_scores": relation_scores,
            "relation_logits": relation_logits,
        }

        if relation_to_cluster_ids is not None:
            output_dict = self.predict_labels(
                relation_scores, relation_logits,
                candidate_relations_labels_tensor, output_dict)

        return output_dict
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            evidence: torch.IntTensor = None,
            pad_idx=-1,
            max_select=5,
            gamma=0.95,
            teacher_forcing_ratio=1,
            features=None,
            metadata=None) -> Dict[str, torch.Tensor]:

        # pylint: disable=arguments-differ
        """
        Select evidence sentences with the pointer extractor, classify the
        claim against the selection, and assemble the training losses.

        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``. Its mask is unpacked as
            ``(batch_size, num_evidence, max_premise_length)``.
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``. Despite the default, this method indexes
            it for every example, so it must be provided.
        evidence : torch.IntTensor, optional (default = None)
            From a ``ListField``; gold evidence indices padded with
            ``pad_idx``. Despite the default, this method reads its batch
            size, so it must be provided.
        pad_idx : int, optional (default = -1)
            Padding value used in ``evidence``.
        max_select : int, optional (default = 5)
            Maximum number of sentences kept per example.
        gamma : float, optional (default = 0.95)
            Discount; the reward at selection step ``t`` is scaled by
            ``gamma ** t``.
        teacher_forcing_ratio : float, optional (default = 1)
            Gold evidence is handed to the extractor only when
            ``random.random() > teacher_forcing_ratio``, the label is not
            the not-enough-info label, and the example has at least one
            non-pad evidence index.
        features : optional (default = None)
            Extra per-sentence features, passed to the sentence selection
            model and gathered for the selected sentences.
        metadata : optional (default = None)
            Forwarded to ``self._fever``.

        Returns
        -------
        An output dictionary consisting of:

        predicted_sentences : torch.LongTensor
            Selected sentence indices, padded with ``-1`` to ``max_select``.
        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        label_sequence_logits : torch.FloatTensor, optional
            Per-step label logits, only when decoder states are used.
        advantage, mse : optional
            Mean advantage and critic loss, only when the extractor
            returned baseline scores and its parameters are not fixed.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        def _first_scalar(tensor):
            # Older PyTorch exposes scalars through ``.data[0]``; newer
            # versions raise IndexError on 0-dim tensors and need
            # ``.item()`` instead.
            try:
                return tensor.data[0]
            except IndexError:
                return tensor.item()

        premise_mask = get_text_field_mask(premise,
                                           num_wrapping_dims=1).float()

        hypothesis_mask = get_text_field_mask(hypothesis).float()

        aggregated_input = self._sentence_selection_esim(premise,
                                                         hypothesis,
                                                         premise_mask,
                                                         hypothesis_mask,
                                                         wrap_output=True,
                                                         features=features)

        batch_size, num_evidence, max_premise_length = premise_mask.shape
        aggregated_input = aggregated_input.view(batch_size, num_evidence, -1)
        # A candidate sentence counts as present when it has at least one
        # unmasked token.
        evidence_mask = premise_mask.sum(dim=-1).gt(0)
        evidence_len = evidence_mask.view(batch_size, -1).sum(dim=-1)

        # Per-example results of the extraction step.
        valid_indices = []
        indices = []
        probs = []
        baselines = []
        states = []
        selected_evidence_lengths = []
        for i in range(evidence.size(0)):
            curr_label = _first_scalar(label[i])

            # NOTE(review): ``random.random()`` is always < 1, so with the
            # default ratio of 1 gold evidence is never passed here.
            gold_evidence = None
            if random.random(
            ) > teacher_forcing_ratio and curr_label != self._nei_label and float(
                    evidence[i].ne(pad_idx).sum()) > 0:
                gold_evidence = evidence[i]

            output = self._ptr_extract_summ(aggregated_input[i],
                                            max_select,
                                            evidence_mask[i],
                                            gold_evidence,
                                            beam_size=self._beam_size)
            states.append(output.get('states', []))

            valid_idx = []
            curr_evidence_len = _first_scalar(evidence_len[i])
            for idx in output['idxs'][:min(max_select, curr_evidence_len)]:
                curr_idx = _first_scalar(idx.view(-1))

                # An index equal to ``num_evidence`` ends the selection
                # (presumably the extractor's stop symbol -- confirm).
                if curr_idx == num_evidence:
                    break
                valid_idx.append(curr_idx)

                # Indices past the real sentences are remapped to sentence 0.
                if valid_idx[-1] >= curr_evidence_len:
                    valid_idx[-1] = 0

            #TODO: if it selects none, use the first one?

            selected_evidence_lengths.append(len(valid_idx))
            indices.append(valid_idx)
            if 'scores' in output:
                baselines.append(output['scores'][:len(valid_idx)])
            if 'probs' in output:
                probs.append(output['probs'][:len(valid_idx)])

            # Pad every selection to ``max_select`` entries with -1.
            valid_indices.append(torch.LongTensor(valid_idx + \
                                             [-1]*(max_select-len(valid_idx))))

        output_dict = {'predicted_sentences': torch.stack(valid_indices)}

        predictions = torch.autograd.Variable(torch.stack(valid_indices))

        selected_premise = {}
        index = predictions.unsqueeze(2).expand(batch_size, max_select,
                                                max_premise_length)
        # B x num_selected mask over the selection slots actually filled.
        selection_mask = torch.autograd.Variable(
            len_mask(selected_evidence_lengths,
                     max_len=max_select,
                     dtype=torch.FloatTensor))

        # Padded slots hold -1, which is not a valid gather index; multiplying
        # by the mask turns them into 0.
        index = index * selection_mask.long().unsqueeze(-1)
        if torch.cuda.is_available() and premise_mask.is_cuda:
            device_idx = premise_mask.get_device()
            index = index.cuda(device_idx)
            selection_mask = selection_mask.cuda(device_idx)
            predictions = predictions.cuda(device_idx)

        if self._use_decoder_states:
            # Classify from the extractor's decoder states, one label per
            # selection step.
            states = torch.cat(states, dim=0)
            label_sequence = make_label_sequence(predictions,
                                                 evidence,
                                                 label,
                                                 pad_idx=pad_idx,
                                                 nei_label=self._nei_label)
            batch_size, max_length, _ = states.shape
            label_logits = self._entailment_esim(
                features=states.view(batch_size * max_length, 1, -1))
            if 'loss' not in output_dict:
                output_dict['loss'] = 0
            output_dict['loss'] += sequence_loss(label_logits.view(
                batch_size, max_length, -1),
                                                 label_sequence,
                                                 self._evidence_loss,
                                                 pad_idx=pad_idx)
            output_dict['label_sequence_logits'] = label_logits.view(
                batch_size, max_length, -1)
            # The final step's logits serve as the example-level prediction.
            label_logits = output_dict['label_sequence_logits'][:, -1, :]
        else:
            # Gather the selected sentences and run the entailment model on
            # them.
            for key in premise:
                selected_premise[key] = torch.gather(premise[key],
                                                     dim=1,
                                                     index=index)

            selected_mask = torch.gather(premise_mask, dim=1, index=index)

            # Zero out the masks of padded selection slots.
            selected_mask = selected_mask * selection_mask.unsqueeze(-1)

            selected_features = None
            if features is not None:
                feature_index = predictions.unsqueeze(2).expand(
                    batch_size, max_select, features.size(-1))
                feature_index = feature_index * selection_mask.long(
                ).unsqueeze(-1)
                selected_features = torch.gather(features,
                                                 dim=1,
                                                 index=feature_index)
                # HACK (marked "UNDO" by the original author): only the
                # first 200 feature dimensions are kept. The slice now lives
                # inside the guard so that ``features=None`` no longer
                # fails on ``None[:, :, :200]``.
                selected_features = selected_features[:, :, :200]

            label_logits = self._entailment_esim(selected_premise,
                                                 hypothesis,
                                                 premise_mask=selected_mask,
                                                 features=selected_features)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict.update({
            "label_logits": label_logits,
            "label_probs": label_probs
        })

        #get fever score, recall, and accuracy

        if len(label.shape) > 1:
            self._accuracy(label_logits, label.squeeze(-1))
        else:
            self._accuracy(label_logits, label)

        fever_reward = self._fever(label_logits,
                                   label.squeeze(-1),
                                   predictions,
                                   evidence,
                                   indices=True,
                                   pad_idx=pad_idx,
                                   metadata=metadata)

        if not self._fix_sentence_extraction_params:
            #multiply the reward for the support/refute labels by a constant so that the model selects the correct evidence instead of just trying to predict the not enough info labels
            fever_reward = fever_reward * label.squeeze(-1).ne(
                self._nei_label
            ) * self._ei_reward_weight + fever_reward * label.squeeze(-1).eq(
                self._nei_label)

            #compute discounted reward: one entry per selected sentence
            rewards = []
            for i in range(evidence.size(0)):
                rewards.append(
                    gamma**torch.arange(selected_evidence_lengths[i]).float() *
                    fever_reward[i].float())

            reward = torch.autograd.Variable(torch.cat(rewards),
                                             requires_grad=False)
            if torch.cuda.is_available() and fever_reward.is_cuda:
                device_idx = fever_reward.get_device()
                reward = reward.cuda(device_idx)

            if len(baselines):
                # Flatten the per-example lists so that they line up with
                # the flat ``reward`` vector.
                indices = list(itertools.chain(*indices))
                probs = list(itertools.chain(*probs))
                baselines = list(itertools.chain(*baselines))

                # standardize rewards
                reward = (reward - reward.mean()) / (
                    reward.std() + float(np.finfo(np.float32).eps))

                baseline = torch.cat(baselines).squeeze()
                avg_advantage = 0
                losses = []
                for action, p, r, b in zip(indices, probs, reward, baseline):
                    action = torch.autograd.Variable(torch.LongTensor([action
                                                                       ]))
                    if torch.cuda.is_available() and r.is_cuda:
                        device_idx = r.get_device()
                        action = action.cuda(device_idx)

                    # Policy-gradient term weighted by reward minus baseline.
                    advantage = r - b
                    avg_advantage += advantage
                    losses.append(-p.log_prob(action) *
                                  (advantage / len(indices)))  # divide by T*B

                critic_loss = F.mse_loss(baseline, reward)

                # NOTE(review): this assignment replaces any loss stored
                # earlier (e.g. by the decoder-state branch above).
                output_dict['loss'] = critic_loss + sum(losses)

                # Both values are converted inside one try so that they
                # always go through the same scalar-access path.
                try:
                    output_dict['advantage'] = avg_advantage.data[0] / len(
                        indices)
                    output_dict['mse'] = critic_loss.data[0]
                except IndexError:
                    output_dict['advantage'] = avg_advantage.item() / len(
                        indices)
                    output_dict['mse'] = critic_loss.item()

        if self.training and self._train_gold_evidence:

            if 'loss' not in output_dict:
                output_dict['loss'] = 0
            # Supervise the extractor only when the batch holds at least one
            # real (non-pad) gold evidence index. The check honours
            # ``pad_idx`` rather than assuming the padding value is -1.
            if float(evidence.ne(pad_idx).sum()) > 0:
                if len(evidence.shape) > 2:
                    evidence = evidence.squeeze(-1)
                output = self._ptr_extract_summ(
                    aggregated_input, None, None, evidence,
                    evidence_len.long().data.cpu().numpy().tolist())

                loss = sequence_loss(output['scores'][:, :-1, :],
                                     evidence,
                                     self._evidence_loss,
                                     pad_idx=pad_idx)

                output_dict['loss'] += self.lambda_weight * loss

        if not self._fix_entailment_params:
            if self._use_decoder_states:
                if self.training:
                    label_sequence = make_label_sequence(
                        evidence,
                        evidence,
                        label,
                        pad_idx=pad_idx,
                        nei_label=self._nei_label)
                    # NOTE(review): ``output`` is whatever was assigned last:
                    # the gold-evidence run above if it executed, otherwise
                    # the final example of the extraction loop -- confirm
                    # this is intended.
                    batch_size, max_length, _ = output['states'].shape
                    label_logits = self._entailment_esim(
                        features=output['states'][:, 1:, :].contiguous().view(
                            batch_size * (max_length - 1), 1, -1))
                    if 'loss' not in output_dict:
                        output_dict['loss'] = 0
                    output_dict['loss'] += sequence_loss(label_logits.view(
                        batch_size, max_length - 1, -1),
                                                         label_sequence,
                                                         self._evidence_loss,
                                                         pad_idx=pad_idx)
            else:
                #TODO: only update classifier if we have correct evidence

                # Examples whose reward equals 2**7 are left out of the
                # classification loss (presumably a sentinel emitted by
                # ``self._fever`` -- confirm).
                mask = fever_reward != 2**7
                target = label.view(-1).masked_select(mask)

                logit = label_logits.masked_select(
                    mask.unsqueeze(1).expand_as(
                        label_logits)).contiguous().view(
                            -1, label_logits.size(-1))

                loss = self._loss(logit, target.long())
                if 'loss' in output_dict:
                    output_dict["loss"] += self.lambda_weight * loss
                else:
                    output_dict["loss"] = self.lambda_weight * loss

        return output_dict