Example #1
    def test_masked_mean(self):
        # Testing the general masked 1D case.
        vector_1d = torch.FloatTensor([1.0, 12.0, 5.0])
        mask_1d = torch.FloatTensor([1.0, 0.0, 1.0])
        vector_1d_mean = util.masked_mean(vector_1d, mask_1d, dim=0).data.numpy()
        assert_array_almost_equal(vector_1d_mean, 3.0)

        # Testing if all masks are zero, the output will be arbitrary, but it should not be nan.
        vector_1d = torch.FloatTensor([1.0, 12.0, 5.0])
        mask_1d = torch.FloatTensor([0.0, 0.0, 0.0])
        vector_1d_mean = util.masked_mean(vector_1d, mask_1d, dim=0).data.numpy()
        assert not numpy.isnan(vector_1d_mean).any()

        # Testing batch value and batch masks
        matrix = torch.FloatTensor([[1.0, 12.0, 5.0], [-1.0, -2.0, 3.0]])
        mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
        matrix_mean = util.masked_mean(matrix, mask, dim=-1).data.numpy()
        assert_array_almost_equal(matrix_mean, numpy.array([3.0, -1.5]))

        # Testing keepdim for batch value and batch masks
        matrix = torch.FloatTensor([[1.0, 12.0, 5.0], [-1.0, -2.0, 3.0]])
        mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
        matrix_mean = util.masked_mean(matrix, mask, dim=-1, keepdim=True).data.numpy()
        assert_array_almost_equal(matrix_mean, numpy.array([[3.0], [-1.5]]))

        # Testing broadcast
        matrix = torch.FloatTensor([[[1.0, 2.0], [12.0, 3.0], [5.0, -1.0]],
                                    [[-1.0, -3.0], [-2.0, -0.5], [3.0, 8.0]]])
        mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]).unsqueeze(-1)
        matrix_mean = util.masked_mean(matrix, mask, dim=1).data.numpy()
        assert_array_almost_equal(matrix_mean, numpy.array([[3.0, 0.5], [-1.5, -1.75]]))
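
A note on the all-zeros-mask assertion above: masked_mean divides by a clamped mask count rather than the raw count, which is what keeps the result finite. A minimal sketch of that behavior (an illustration with an assumed epsilon, not AllenNLP's exact implementation):

import torch

def masked_mean_sketch(vector, mask, dim, keepdim=False, eps=1e-8):
    # Zero out masked positions, sum over `dim`, and divide by the clamped mask count.
    # Clamping the denominator is what keeps an all-zero mask from producing NaN.
    value_sum = (vector * mask).sum(dim=dim, keepdim=keepdim)
    value_count = mask.sum(dim=dim, keepdim=keepdim)
    return value_sum / value_count.clamp(min=eps)

# The broadcast case in the test works the same way: a (2, 3, 1) mask broadcasts
# against (2, 3, 2) values, and the division broadcasts the (2, 1) counts.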
Example #2
File: smbop.py Project: hyukyu/SmBop
    def _encode_utt_schema(self, enc, offsets, relation, lengths):
        embedded_utterance_schema = self.emb_q(enc)

        (
            embedded_utterance_schema,
            embedded_utterance_schema_mask,
        ) = vec_utils.batched_span_select(embedded_utterance_schema, offsets)
        embedded_utterance_schema = masked_mean(
            embedded_utterance_schema,
            embedded_utterance_schema_mask.unsqueeze(-1),
            dim=-2,
        )

        relation_mask = (relation >= 0).float()  # TODO: fixme
        torch.abs(relation, out=relation)
        embedded_utterance_schema = self._emb_to_action_dim(
            embedded_utterance_schema)
        enriched_utterance_schema = self._schema_encoder(
            embedded_utterance_schema, relation.long(), relation_mask)

        utterance_schema, utterance_schema_mask = vec_utils.batched_span_select(
            enriched_utterance_schema, lengths)
        utterance, schema = torch.split(utterance_schema, 1, dim=1)
        utterance_mask, schema_mask = torch.split(utterance_schema_mask,
                                                  1,
                                                  dim=1)
        utterance_mask = torch.squeeze(utterance_mask, 1)
        schema_mask = torch.squeeze(schema_mask, 1)
        embedded_utterance = torch.squeeze(utterance, 1)
        schema = torch.squeeze(schema, 1)
        return schema, schema_mask, embedded_utterance, utterance_mask
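
The masked_mean call above pools the wordpiece vectors inside each span returned by batched_span_select down to a single vector per span. A self-contained sketch of that pattern, assuming AllenNLP's batched_span_select and masked_mean (the shapes and span indices are made up for illustration):

import torch
from allennlp.nn.util import batched_span_select, masked_mean

sequence = torch.randn(2, 6, 4)                     # (batch, seq_len, dim)
spans = torch.tensor([[[0, 1], [2, 2], [3, 5]],
                      [[0, 0], [1, 3], [4, 5]]])    # (batch, num_spans, 2), inclusive ends
# span_embeddings: (batch, num_spans, max_span_width, dim)
# span_mask:       (batch, num_spans, max_span_width)
span_embeddings, span_mask = batched_span_select(sequence, spans)
pooled = masked_mean(span_embeddings, span_mask.unsqueeze(-1), dim=-2)
assert pooled.shape == (2, 3, 4)                    # one pooled vector per span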
Example #3
File: editor.py Project: isomap/factedit
    def _init_state(self, triples: Dict[str, torch.LongTensor],
                    predicate: Dict[str, torch.LongTensor],
                    draft: Dict[str, torch.LongTensor],
                    triple_ids: torch.LongTensor) -> Dict[str, torch.Tensor]:
        emb_pred = util.masked_mean(
            self.EMB(predicate),
            util.get_text_field_mask(
                predicate,
                num_wrapping_dims=1,
            ).unsqueeze(-1), 2)
        emb_triple = self.EMB(triples)
        triple_mask = util.get_text_field_mask(triples)
        flat_triples = torch.cat((emb_triple.flatten(2, 3), emb_pred), dim=-1)

        encoded_triples = self.FACT_ENCODER(flat_triples)

        emb_draft = self.EMB(draft)
        draft_mask = util.get_text_field_mask(draft)
        end_point = (draft_mask.sum(dim=1) - 1)
        encoded_draft = self.BUFFER(emb_draft, draft_mask)

        return {
            "draft_mask": draft_mask,
            "triple_mask": triple_mask,
            "end_point": end_point,
            "encoded_triple": encoded_triples,
            "encoded_draft": encoded_draft,
            "triple_tokens": triples["tokens"][:, :, -1],
            "triple_token_ids": triple_ids
        }
Example #4
    def forward(self, **kwargs) -> torch.FloatTensor:
        mask = kwargs['mask']
        embedded_text = kwargs['embedded_text']
        encoded_output = self._architecture(embedded_text, mask)
        encoded_repr = []
        for aggregation in self._aggregations:
            if aggregation == "meanpool":
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoded_output * broadcast_mask
                encoded_text = masked_mean(context_vectors,
                                           broadcast_mask,
                                           dim=1,
                                           keepdim=False)
            elif aggregation == 'maxpool':
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoded_output * broadcast_mask
                encoded_text = masked_max(context_vectors,
                                          broadcast_mask,
                                          dim=1)
            elif aggregation == 'final_state':
                is_bi = self._architecture.is_bidirectional()
                encoded_text = get_final_encoder_states(encoded_output,
                                                        mask,
                                                        is_bi)
            elif aggregation == 'attention':
                alpha = self._attention_layer(encoded_output)
                alpha = masked_log_softmax(alpha, mask.unsqueeze(-1), dim=1).exp()
                encoded_text = alpha * encoded_output
                encoded_text = encoded_text.sum(dim=1)
            else:
                raise ConfigurationError(f"{aggregation} aggregation not available.")
            encoded_repr.append(encoded_text)

        encoded_repr = torch.cat(encoded_repr, 1)
        return encoded_repr
Example #5
    def _encode_definition(
            self, definition: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        # [batch_size, seq_len]
        definition_mask = util.get_text_field_mask(definition)
        # [batch_size, seq_len, emb_dim]
        embedded_definition = self.text_embedder(definition)

        # either [batch_size, emb_dim] or [batch_size, seq_len, emb_dim]
        encoded_definition = self.definition_encoder(embedded_definition,
                                                     definition_mask)
        # if len(encoded_definition.size()) == 3:
        if self.definition_pooling == 'last':
            # [batch_size, emb_dim]
            encoded_definition = util.get_final_encoder_states(
                encoded_definition, definition_mask)
        elif self.definition_pooling == 'max':
            # encoded_definition = F.adaptive_max_pool1d(encoded_definition.transpose(1, 2), 1).squeeze(2)
            encoded_definition = util.masked_max(encoded_definition,
                                                 definition_mask.unsqueeze(2),
                                                 dim=1)
        elif self.definition_pooling == 'mean':
            # encoded_definition = F.adaptive_avg_pool1d(encoded_definition.transpose(1, 2), 1).squeeze(2)
            encoded_definition = util.masked_mean(encoded_definition,
                                                  definition_mask.unsqueeze(2),
                                                  dim=1)
        elif self.definition_pooling == 'self-attentive':
            self_attentive_logits = self.self_attentive_pooling_projection(
                encoded_definition).squeeze(2)
            self_weights = util.masked_softmax(self_attentive_logits,
                                               definition_mask)
            encoded_definition = util.weighted_sum(encoded_definition,
                                                   self_weights)
        # [batch_size, emb_dim]
        definition_embedding = self.definition_feedforward(encoded_definition)

        # [batch_size, vocab_size(num_class)]
        definition_logits = self.definition_decoder(definition_embedding)
        # [batch_size, seq_len, vocab_size]
        sequence_definition_logits = definition_logits.unsqueeze(1).repeat(
            1, definition_mask.size(1), 1)

        # ``average`` can be None, "batch", or "token"
        # loss for ``average==None`` is a vector of shape (batch_size,); otherwise, a scalar
        targets = definition['tokens'].clone()
        if self.limited_word_vocab_size is not None:
            targets[targets >= self.limited_word_vocab_size] = self._oov_index
        cross_entropy_loss = util.sequence_cross_entropy_with_logits(
            sequence_definition_logits,
            targets,
            # definition['tokens'],
            weights=definition_mask,
            average='token')

        return {
            "definition_embedding": definition_embedding,
            "cross_entropy_loss": cross_entropy_loss
        }
Example #6
 def forward(self, # pylint: disable=arguments-differ
             premises_relevance_logits: torch.Tensor,
             premises_presence_mask: torch.Tensor,
             relevance_presence_mask: torch.Tensor) -> torch.Tensor: # pylint: disable=unused-argument
     premises_relevance_logits = replace_masked_values(premises_relevance_logits, premises_presence_mask, -1e10)
     binary_losses = self._loss(premises_relevance_logits, relevance_presence_mask)
     coverage_losses = masked_mean(binary_losses, premises_presence_mask, dim=1)
     coverage_loss = coverage_losses.mean()
     return coverage_loss
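
The forward above first pushes the logits of absent premises to a large negative value, computes an elementwise loss, and then averages only over the premises that are actually present. A standalone sketch of the same pattern (the BCEWithLogitsLoss stand-in for self._loss and the toy tensors are assumptions; the mask dtype should match whatever your AllenNLP version expects):

import torch
from allennlp.nn.util import masked_mean, replace_masked_values

loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")      # assumed stand-in for self._loss

relevance_logits = torch.tensor([[2.0, -1.0, 0.5],
                                 [0.3,  0.0, 0.0]])          # (batch, num_premises)
presence_mask = torch.tensor([[1.0, 1.0, 1.0],
                              [1.0, 0.0, 0.0]])              # 0 marks padded premises
relevance_labels = torch.tensor([[1.0, 0.0, 1.0],
                                 [0.0, 0.0, 0.0]])

relevance_logits = replace_masked_values(relevance_logits, presence_mask, -1e10)
binary_losses = loss_fn(relevance_logits, relevance_labels)  # elementwise, shape (batch, num_premises)
coverage_loss = masked_mean(binary_losses, presence_mask, dim=1).mean()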
Example #7
    def pred(self, ws, ctxs):  # ws: unused
        # ctxs : B,C,S
        ctxs = torch.stack(ctxs)
        x = self.wvec[ctxs].cuda()  # B,C,S,D
        mask = (ctxs != -1).cuda()  # B,C,S
        B, C, S, D = x.shape

        x = x.reshape(B * C, S, D)
        mask = mask.reshape(B * C, S)
        x = self.posenc(x)  # B*C,S,D
        x = self.ctxenc(x, mask)  # B*C,S,D
        x = masked_mean(x, mask[:, :, None], dim=-2)  # B*C,D

        x = x.reshape(B, C, D)
        mask = mask.reshape(B, C, S).any(-1)  # B,C
        x = self.ctxagg(x, mask)  # B,C,D
        x = masked_mean(x, mask[:, :, None], dim=-2)  # B,D
        return x
Example #8
    def _compute_answer(self, premise_memory: torch.Tensor,
                        hypothesis_memory: torch.Tensor,
                        premise_mask: torch.Tensor,
                        hypothesis_mask: torch.Tensor) -> torch.Tensor:
        batch_size = premise_memory.size(0)
        num_labels = self._output_logit.get_output_dim()

        # Shape: (batch_size, hypothesis_length)
        hypothesis_attention = util.masked_softmax(
            self._answer_attention(hypothesis_memory).squeeze(),
            hypothesis_mask,
        )
        # Shape: (batch_size, embedding_dim)
        answer_state = util.weighted_sum(hypothesis_memory,
                                         hypothesis_attention)

        label_prob_steps: torch.Tensor = answer_state.new_zeros(
            (batch_size, num_labels, self._answer_steps))
        for step in range(self._answer_steps):
            # Shape: (batch_size, premise_length)
            premise_attention = self._answer_bilinear(answer_state,
                                                      premise_memory,
                                                      premise_mask)
            # Shape: (batch_size, embedding_dim)
            cell_input = util.weighted_sum(premise_memory, premise_attention)

            answer_state = self._answer_gru_cell(cell_input, answer_state)

            output_hidden = torch.cat([
                answer_state,
                cell_input,
                (answer_state - cell_input).abs(),
                answer_state * cell_input,
            ],
                                      dim=-1)
            label_logits = self._output_logit(
                self._output_feedforward(output_hidden))
            label_prob_steps[:, :, step] = label_logits.softmax(-1)

        if self.training and self._dropout:
            # stochastic prediction dropout
            binary_mask = (torch.rand(
                (batch_size, self._answer_steps)) > self._dropout.p).to(
                    label_prob_steps.device)
            label_probs = util.masked_mean(label_prob_steps,
                                           binary_mask.float().unsqueeze(1),
                                           dim=2)
            label_probs = util.replace_masked_values(
                label_probs,
                binary_mask.sum(1, keepdim=True).bool().float(),
                1.0 / num_labels)
        else:
            label_probs = label_prob_steps.mean(2)

        return label_probs
Example #9
File: editor.py Project: isomap/factedit
    def _decoder_init(self, state):
        mean_draft = util.masked_mean(state["encoded_draft"],
                                      state["draft_mask"].unsqueeze(-1), 1)
        mean_triple = util.masked_mean(state["encoded_triple"],
                                       state["triple_mask"].unsqueeze(-1), 1)
        concatenated = torch.cat((mean_draft, mean_triple), dim=-1)
        batch_size = state["draft_mask"].size(0)

        zeros = mean_draft.new_zeros((batch_size, self.decoder_size))
        state["stream_hidden"], state["stream_context"] = self.U(
            concatenated), zeros
        state["draft_pointer"] = state["draft_mask"].new_ones((batch_size, ))

        action_mask = mean_draft.new_ones((batch_size, self.vocab_size))
        action_mask[:, self.PAD] = 0
        action_mask[:, self.END] = 0

        state["action_mask"] = action_mask

        return state
Example #10
 def _get_summary_of_encoder_outputs(self, encoder_outputs, source_mask):
     # This returns the final encoder state for RNN encoders,
     # and the mean of the outputs for other encoders
     if type(self._encoder) == PytorchSeq2SeqWrapper:
         summary = util.get_final_encoder_states(
             encoder_outputs, source_mask, self._encoder.is_bidirectional())
     else:
         summary = masked_mean(encoder_outputs,
                               source_mask.unsqueeze(-1).to(
                                   encoder_outputs.device),
                               dim=1,
                               keepdim=False)
     return summary
Example #11
def pool(vector: torch.Tensor,
         mask: torch.Tensor,
         dim: int,
         pooling: str,
         is_bidirectional: bool) -> torch.Tensor:
    if pooling == "max":
        return masked_max(vector, mask, dim)
    elif pooling == "mean":
        return masked_mean(vector, mask, dim)
    elif pooling == "sum":
        return torch.sum(vector, dim)
    elif pooling == "final":
        return get_final_encoder_states(vector, mask, is_bidirectional)
    else:
        raise ValueError(f"'{pooling}' is not a valid pooling operation.")
Example #12
File: model.py Project: EntilZha/qb-bert
 def forward(self,
             text: Dict[str, torch.LongTensor],
             metadata=None,
             page: torch.IntTensor = None):  # pylint: disable=arguments-differ
     input_ids: torch.LongTensor = text["text"]
     # Grab the representation of CLS token, which is always first
     if self._pool == "cls":
         bert_emb = self._bert(input_ids)[:, 0, :]
     elif self._pool == "mean":
         mask = (input_ids != 0).long()[:, :, None]
         bert_seq_emb = self._bert(input_ids)
         bert_emb = util.masked_mean(bert_seq_emb, mask, dim=1)
     else:
         raise ValueError("Invalid config")
     return self._hidden_to_output(bert_emb, page)
Example #13
def find_max_window(p_prob, mask, offset):
    batch_size = p_prob.size(0)
    out_idx = []
    for b in range(batch_size):
        mean_prob = allenutil.masked_mean(p_prob[b], mask[b], dim=-1)
        max_idx = np.argmax(p_prob[b].detach().cpu().numpy())
        # There are many possible ways to determine max_idx; the method above simply picks the token with the highest probability.
        # You could also use other heuristics, such as scoring the total probability of a 3-gram window instead of each token.
        # max_idx = find_max_ind(p_prob[b])
        max_value = p_prob[b][max_idx]
        start = find_surrounding_with_max(p_prob[b], max_idx,
                                          max(4 * mean_prob, 0.0), 'L')
        end = find_surrounding_with_max(p_prob[b], max_idx,
                                        max(4 * mean_prob, 0.0), 'R')
        start += offset[b].tolist()
        end += offset[b].tolist()
        out_idx.append([start, end])
    return out_idx
Example #14
    def pool_node_embeddings(self, last_layers, masks, gdata, batch_num_nodes):
        """
        Convert wordpiece embeddings into word (i.e. node) embeddings using the alignment in
        wpidx2graphid = gdata['wpidx2graphid']

        Parameters:
            g_data: dictinoary with values having shape:
                (bsz, ...)
            masks: (bsz, max_sent_pair_len)
            last_layers: (bsz, max_sent_pair_len, emb_dim)

        Returns:
            node_embs: (bsz, max_num_nodes, emb_dim)
            node_embeddings_mask: (bsz, max_num_nodes)
        """
        wpidx2graphid = gdata['wpidx2graphid']  # (bsz, max_sent_len, max_n_nodes)
        device = last_layers.device
        bsz, max_sent_len, max_n_nodes = wpidx2graphid.shape
        emb_dim = last_layers.shape[-1]
        assert max(batch_num_nodes) == wpidx2graphid.shape[-1]

        # the following logic happens to work if the graph is empty, in which case its sentence_end is guaranteed to be 1 (exclusive)
        masks_cumsum = masks.cumsum(1)
        sentence_starts = first_true_idx(masks, 1, masks_cumsum)
        sentence_ends = last_true_idx(masks, 1, masks_cumsum) + 1  # exclusive
        max_sentence_len = (sentence_ends - sentence_starts).max()

        # we use a for loop here since rolling across just the batch dimension shouldn't be very expensive;
        # that said, could this be done without a loop?
        rolled_last_layers = torch.stack([last_layer.roll(-sentence_start.item(), dims=0) for last_layer, sentence_start in zip(last_layers, sentence_starts)])
        segmented_last_layers = rolled_last_layers[:, :max_sentence_len, :]  # (bsz, max_sent_len, emb_dim)
        assert segmented_last_layers.shape[:2] == wpidx2graphid.shape[:2]

        # (bsz, max_sent_len, max_n_nodes, emb_dim)
        expanded_wpidx2graphid = wpidx2graphid.unsqueeze(-1).expand(-1, -1, -1, emb_dim)
        expanded_segmented_last_layers = segmented_last_layers.unsqueeze(2).expand(-1, -1, max_n_nodes, -1)

        # (bsz, max_n_nodes, emb_dim)
        node_embeddings = masked_mean(expanded_segmented_last_layers, expanded_wpidx2graphid, 1)

        node_embeddings = torch.where(expanded_wpidx2graphid.any(1), node_embeddings, torch.tensor(0., device=device))  # some nodes don't have corresponding wordpieces
        node_embeddings_mask = torch.arange(max(batch_num_nodes), device=device).expand(bsz, -1) < torch.tensor(batch_num_nodes, dtype=torch.long, device=device).unsqueeze(1)

        return node_embeddings, node_embeddings_mask
Example #15
    def forward(self,
                document,
                query=None,
                label=None,
                metadata=None,
                rationale=None,
                **kwargs) -> Dict[str, Any]:
        #pylint: disable=arguments-differ

        bert_document = self.combine_document_query(document, query)

        last_hidden_states, _ = self._bert_model(
            bert_document["bert"]["wordpiece-ids"],
            attention_mask=bert_document["bert"]["wordpiece-mask"],
            position_ids=bert_document["bert"]["position-ids"],
            token_type_ids=bert_document["bert"]["type-ids"],
        )

        token_embeddings, span_mask = generate_embeddings_for_pooling(
            last_hidden_states,
            bert_document["bert"]['document-starting-offsets'],
            bert_document["bert"]['document-ending-offsets'])

        token_embeddings = util.masked_mean(token_embeddings,
                                            span_mask.unsqueeze(-1),
                                            dim=2)
        token_embeddings = token_embeddings * bert_document['bert'][
            "mask"].unsqueeze(-1)

        logits = torch.nn.functional.softplus(
            self._classification_layer(self._dropout(token_embeddings)))

        a, b = logits[:, :, 0], logits[:, :, 1]
        mask = bert_document['bert']['mask']

        output_dict = {}
        output_dict["a"] = a * mask
        output_dict["b"] = b * mask
        output_dict['mask'] = mask
        output_dict['wordpiece-to-token'] = bert_document['bert'][
            'wordpiece-to-token']
        return output_dict
Example #16
    def _average_image_features(
            self,
            image_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Perform mean pooling of bottom-up image features, while taking care of variable
        ``num_boxes`` in case of adaptive features.

        Extended Summary
        ----------------
        For a single training/evaluation instance, the image features remain the same from the
        first time-step to the maximum decoding step. To keep a clean API, we use an LRU cache --
        which maintains a cache of the last 10 return values based on the call signature, and does
        not actually execute if it is called with the same image features seen at least once in
        the last 10 calls. This saves some computation.

        Parameters
        ----------
        image_features: torch.Tensor
            A tensor of shape ``(batch_size, num_boxes, image_feature_size)``. ``num_boxes`` for
            each instance in a batch might be different. Instances with fewer boxes are padded
            with zeros up to ``num_boxes``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Averaged image features of shape ``(batch_size, image_feature_size)`` and a binary
            mask of shape ``(batch_size, num_boxes)`` which is zero for padded features.
        """

        # shape: (batch_size, num_boxes)
        image_features_mask = torch.sum(torch.abs(image_features), dim=-1) > 0

        # shape: (batch_size, image_feature_size)
        averaged_image_features = masked_mean(
            image_features, image_features_mask.unsqueeze(-1), dim=1)

        return averaged_image_features, image_features_mask
Example #17
    def forward(self, document, rationale=None) -> Dict[str, Any]:
        embedded_text = self._text_field_embedder(document)
        mask = util.get_text_field_mask(document).float()

        embedded_text = self._dropout(
            self._seq2seq_encoder(embedded_text, mask=mask))
        embedded_text = self._feedforward_encoder(embedded_text)

        logits = self._classification_layer(embedded_text).squeeze(-1)
        probs = torch.sigmoid(logits)

        output_dict = {}

        predicted_rationale = (probs > 0.5).long()
        output_dict['predicted_rationale'] = predicted_rationale * mask
        output_dict["prob_z"] = probs * mask

        class_probs = torch.cat([1 - probs.unsqueeze(-1),
                                 probs.unsqueeze(-1)],
                                dim=-1)

        average_rationale_length = util.masked_mean(
            output_dict['predicted_rationale'], mask, dim=-1).mean()
        self._rationale_length(average_rationale_length.item())

        if rationale is not None:
            rationale_loss = F.binary_cross_entropy_with_logits(
                logits, rationale.float(), weight=mask)
            output_dict['rationale_supervision_loss'] = rationale_loss
            output_dict['gold_rationale'] = rationale * mask
            self._rationale_f1_metric(predictions=class_probs,
                                      gold_labels=rationale,
                                      mask=mask)
            self._rationale_supervision_loss(rationale_loss.item())

        return output_dict
Example #18
    def forward(self,
                definition: Dict[str, torch.LongTensor],
                word: Dict[str, torch.LongTensor] = None,
                word_to_definition: torch.Tensor = None,
                **kwargs) -> Dict[str, torch.Tensor]:

        output_dict = {}
        output_dict.update(self._encode_definition(definition))
        output_dict['loss'] = self.alpha * output_dict['cross_entropy_loss']

        if self.beta > 0 and word is not None:
            # [batch_size, seq_len(1)]
            word_in_definition_mask = (word_to_definition !=
                                       self._oov_index).float()
            # [batch_size]
            word_in_definition_mask = word_in_definition_mask.squeeze(dim=1)

            # [batch_size, seq_len(1), emb_dim]
            embedded_word = self.text_embedder({'tokens': word_to_definition})
            # [batch_size, emb_dim]
            embedded_word = embedded_word.squeeze(dim=1)

            mse = self.pdist(output_dict['definition_embedding'],
                             embedded_word)
            consistency_loss = util.masked_mean(mse,
                                                word_in_definition_mask,
                                                dim=0)
            output_dict['consistency_loss'] = consistency_loss

            output_dict['loss'] += self.beta * output_dict['consistency_loss']

            for metric in self.metrics.values():
                metric(output_dict['definition_embedding'], embedded_word,
                       word_in_definition_mask)

        return output_dict
Example #19
    def forward(self,
                document,
                query=None,
                label=None,
                metadata=None,
                rationale=None) -> Dict[str, Any]:
        # pylint: disable=arguments-differ

        generator_dict = self._generator(document, query, label)
        mask = generator_dict["mask"]

        assert "a" in generator_dict
        assert "b" in generator_dict

        a, b = generator_dict['a'], generator_dict['b']
        a = a.clamp(1e-6, 100.)  # extreme values could result in NaNs
        b = b.clamp(1e-6, 100.)  # extreme values could result in NaNs

        output_dict = {}

        sampler = HardKuma([a, b],
                           support=[
                               self.support[0].to(a.device),
                               self.support[1].to(b.device)
                           ])
        generator_dict['predicted_rationale'] = (sampler.mean() >
                                                 0.5).long() * mask

        if self.prediction_mode or not self.training:
            if self._rationale_extractor is None:
                # We constrain rationales to be 0 or 1 strictly. See Pruthi et al
                # for pathologies when this is not the case.
                sample_z = (sampler.mean() > 0.5).long() * mask
            else:
                prob_z = sampler.mean()
                sample_z = self._rationale_extractor.extract_rationale(
                    prob_z, metadata, as_one_hot=True)
                output_dict[
                    "rationale"] = self._rationale_extractor.extract_rationale(
                        prob_z, metadata, as_one_hot=False)
                sample_z = torch.Tensor(sample_z).to(prob_z.device).float()
        else:
            sample_z = sampler.sample()

        sample_z = sample_z * mask

        # Because BERT is BERT
        wordpiece_to_token = generator_dict['wordpiece-to-token']
        wtt0 = torch.where(wordpiece_to_token == -1,
                           torch.tensor([0]).to(wordpiece_to_token.device),
                           wordpiece_to_token)
        wordpiece_sample = util.batched_index_select(sample_z.unsqueeze(-1),
                                                     wtt0)
        wordpiece_sample[wordpiece_to_token.unsqueeze(-1) == -1] = 1.0

        def scale_embeddings(module, input, output):
            output = output * wordpiece_sample
            return output

        hook = self._encoder.embedding_layers[0].register_forward_hook(
            scale_embeddings)

        encoder_dict = self._encoder(
            document=document,
            query=query,
            label=label,
            metadata=metadata,
        )

        hook.remove()

        loss = 0.0

        if label is not None:
            assert "loss" in encoder_dict

            base_loss = F.cross_entropy(encoder_dict["logits"], label)  # (B,)

            lasso_loss = ((1 - sampler.pdf(0.)) * mask).sum(1)
            lengths = mask.sum(1)

            lasso_loss = lasso_loss / (lengths + 1e-9)

            censored_lasso_loss = F.relu(lasso_loss / (lengths + 1e-9) -
                                         self._desired_length)
            censored_lasso_loss = censored_lasso_loss.mean()

            # diff = (sample_z[:, 1:] - sample_z[:, :-1]).abs()
            # mask_last = mask[:, :-1]
            # fused_lasso_loss = diff.sum(-1) / mask_last.sum(-1)

            self._loss_tracks["_lasso_loss"](lasso_loss.mean().item())
            self._loss_tracks["_censored_lasso_loss"](
                censored_lasso_loss.mean().item())
            # self._loss_tracks["_fused_lasso_loss"](fused_lasso_loss.mean().item())
            self._loss_tracks["_base_loss"](base_loss.mean().item())

            generator_loss = self._reg_loss_lambda * censored_lasso_loss

            self._loss_tracks["_generator_loss"](generator_loss.mean().item())

            loss += (base_loss + generator_loss).mean()

        output_dict["probs"] = encoder_dict["probs"]
        output_dict["predicted_labels"] = encoder_dict["predicted_labels"]

        output_dict["loss"] = loss
        output_dict["gold_labels"] = label
        output_dict["metadata"] = metadata

        output_dict["predicted_rationale"] = generator_dict[
            "predicted_rationale"]

        self._loss_tracks["_rat_length"](util.masked_mean(
            generator_dict["predicted_rationale"], mask, dim=-1).mean().item())

        self._call_metrics(output_dict)

        return output_dict
Example #20
    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        len_q=None,
        bs_seperator_index=None,
        s_first=True,
        max_q_length=30,
        max_s_length=200,
        max_b_length=400,
        sp_relevance=None,
        sp_tp_polarity=None,
        tp_relevance=None,
        object1_label=None,
        object2_label=None,
        SP_Object1_label=None,
        SP_Object2_label=None,
        SP_Back_label=None,
        TP_Back_label=None,
    ):

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        torch.set_printoptions(precision=8, sci_mode=False)
        # Step 1: Encoder. Split the background, situation, and question out of the full contextual representation, then pad each to a fixed length.
        sequence_output = outputs[0]
        batch_size = sequence_output.size(0)
        max_seq_length = sequence_output.size(1)
        hidden_size = sequence_output.size(2)
        device = outputs[0].device
        padded_HQ = torch.zeros([batch_size, max_q_length,
                                 hidden_size]).to(device)
        padded_mask_HQ = torch.zeros([batch_size, max_q_length]).to(device)
        padded_HS = torch.zeros([batch_size, max_s_length,
                                 hidden_size]).to(device)
        padded_mask_HS = torch.zeros([batch_size, max_s_length]).to(device)
        padded_HB = torch.zeros([batch_size, max_b_length,
                                 hidden_size]).to(device)
        padded_mask_HB = torch.zeros([batch_size, max_b_length]).to(device)
        s_inds = []
        for ind in range(batch_size):
            try:
                mask_ind = (attention_mask[ind] == 0).nonzero()[0][0] - 1
            except:
                mask_ind = max_seq_length - 1
            HQ, padded_h_q, padded_mask_q = pad_hiddenstate(
                sequence_output[ind, 1:1 + len_q[ind], :], max_q_length)
            if s_first:
                HS, padded_h_s, padded_mask_s = pad_hiddenstate(
                    sequence_output[ind, len_q[ind] +
                                    3:bs_seperator_index[ind] + 1, :],
                    max_s_length)
                HB, padded_h_b, padded_mask_b = pad_hiddenstate(
                    sequence_output[ind,
                                    bs_seperator_index[ind] + 1:mask_ind, :],
                    max_b_length)
            else:
                HB, padded_h_b, padded_mask_b = pad_hiddenstate(
                    sequence_output[ind, len_q[ind] +
                                    3:bs_seperator_index[ind] + 1, :],
                    max_s_length)
                HS, padded_h_s, padded_mask_s = pad_hiddenstate(
                    sequence_output[ind,
                                    bs_seperator_index[ind] + 1:mask_ind, :],
                    max_b_length)
            s_inds.append(bs_seperator_index[ind] - 2 - len_q[ind])
            padded_HQ[ind, :, :] = padded_h_q
            padded_mask_HQ[ind, :] = padded_mask_q
            padded_HS[ind, :, :] = padded_h_s
            padded_mask_HS[ind, :] = padded_mask_s
            padded_HB[ind, :, :] = padded_h_b
            padded_mask_HB[ind, :] = padded_mask_b

        # auxiliary labels also need padding.
        padded_O1 = torch.zeros([batch_size, max_s_length]).to(device)
        padded_O2 = torch.zeros([batch_size, max_s_length]).to(device)
        padded_SP_o1 = torch.zeros([batch_size, max_s_length]).to(device)
        padded_SP_o2 = torch.zeros([batch_size, max_s_length]).to(device)
        padded_SP = torch.zeros([batch_size, max_b_length]).to(device)
        padded_TP = torch.zeros([batch_size, max_b_length]).to(device)
        for ind in range(batch_size):
            try:
                mask_ind = (attention_mask[ind] == 0).nonzero()[0][0]
            except:
                mask_ind = max_seq_length - 1
            if s_first:
                _, padded_o1 = pad_supervison_label(
                    object1_label[ind,
                                  len_q[ind] + 3:bs_seperator_index[ind] + 1],
                    max_s_length) if object1_label is not None else [
                        None, None
                    ]
                _, padded_o2 = pad_supervison_label(
                    object2_label[ind,
                                  len_q[ind] + 3:bs_seperator_index[ind] + 1],
                    max_s_length) if object2_label is not None else [
                        None, None
                    ]
                _, padded_sp_o1 = pad_supervison_label(
                    SP_Object1_label[ind, len_q[ind] +
                                     3:bs_seperator_index[ind] + 1],
                    max_s_length) if SP_Object1_label is not None else [
                        None, None
                    ]
                _, padded_sp_o2 = pad_supervison_label(
                    SP_Object2_label[ind, len_q[ind] +
                                     3:bs_seperator_index[ind] + 1],
                    max_s_length) if SP_Object2_label is not None else [
                        None, None
                    ]
                _, padded_sp = pad_supervison_label(
                    SP_Back_label[ind, bs_seperator_index[ind] + 1:mask_ind],
                    max_b_length) if SP_Back_label is not None else [
                        None, None
                    ]
                _, padded_tp = pad_supervison_label(
                    TP_Back_label[ind, bs_seperator_index[ind] + 1:mask_ind],
                    max_b_length) if TP_Back_label is not None else [
                        None, None
                    ]

            padded_O1[
                ind, :] = padded_o1 if padded_o1 is not None else padded_O1[
                    ind, :]
            padded_O2[
                ind, :] = padded_o2 if padded_o2 is not None else padded_O2[
                    ind, :]
            padded_SP_o1[
                ind, :] = padded_sp_o1 if padded_sp_o1 is not None else padded_SP_o1[
                    ind, :]
            padded_SP_o2[
                ind, :] = padded_sp_o2 if padded_sp_o2 is not None else padded_SP_o2[
                    ind, :]
            padded_SP[
                ind, :] = padded_sp if padded_sp is not None else padded_SP[
                    ind, :]
            padded_TP[
                ind, :] = padded_tp if padded_tp is not None else padded_TP[
                    ind, :]

        # **************************** STEP 2 Find OBJECT/World ****************************
        # [b,n,d] -> [b,n,1]
        ps_object1 = self.find_object1(padded_HS)
        ps_object2 = self.find_object2(padded_HS)
        ps_object1 = allenutil.masked_softmax(ps_object1.squeeze(),
                                              padded_mask_HS,
                                              memory_efficient=True)
        ps_object2 = allenutil.masked_softmax(ps_object2.squeeze(),
                                              padded_mask_HS,
                                              memory_efficient=True)

        #  ****************************STEP 3 Find TP/Effect in B ****************************
        # [b,m,d] -> [b,m,1]
        pb_TP = self.find_TP(padded_HB)
        pb_TP = allenutil.masked_softmax(pb_TP.squeeze(),
                                         padded_mask_HB,
                                         memory_efficient=True)

        #  ****************************STEP 4 Relocate TP/Effect to SP/Cause ****************************
        mean_HS = allenutil.masked_mean(padded_HS,
                                        padded_mask_HS.unsqueeze(-1),
                                        dim=1)
        relocate_bb_similarity_matrix = self.bb_matrix_attention(
            torch.add(mean_HS.unsqueeze(1), padded_HB), padded_HB)
        b2b_attention_matrix = allenutil.masked_softmax(
            relocate_bb_similarity_matrix,
            padded_mask_HB,
            memory_efficient=True,
            dim=-1)
        pb_SP = torch.sum(torch.mul(pb_TP.unsqueeze(-1), b2b_attention_matrix),
                          dim=1)

        # if we don't have labels, these two lines can be commented out
        padded_TP_normal = torch.nn.functional.normalize(padded_TP,
                                                         p=1,
                                                         dim=-1)
        pb_SP_gold = torch.sum(torch.mul(padded_TP_normal.unsqueeze(-1),
                                         b2b_attention_matrix),
                               dim=1)

        # *************************Step 5: Find SP/cause for object/world 1 and object/world 2 ****************************
        # As explained in the comparison module, the two worlds are treated as masks.

        s2b_similarity_matrix = self.bs_bilinear_imilairty(
            padded_HS, padded_HB)
        s2b_similarity_attention = allenutil.masked_softmax(
            s2b_similarity_matrix,
            padded_mask_HB,
            memory_efficient=True,
            dim=1)
        b2s_similarity_matrix = torch.transpose(s2b_similarity_matrix, 1, 2)
        ps_guided_SP = torch.sum(torch.mul(pb_SP.unsqueeze(-1),
                                           b2s_similarity_matrix),
                                 dim=1)
        mask_score_object1 = ps_object1
        mask_score_object2 = ps_object2
        ps_SP_object1 = torch.mul(mask_score_object1, ps_guided_SP)
        ps_SP_object1 = allenutil.masked_softmax(ps_SP_object1,
                                                 padded_mask_HS,
                                                 memory_efficient=True,
                                                 dim=-1)
        ps_SP_object2 = torch.mul(mask_score_object2, ps_guided_SP)
        ps_SP_object2 = allenutil.masked_softmax(ps_SP_object2,
                                                 padded_mask_HS,
                                                 memory_efficient=True,
                                                 dim=-1)

        # gold label, if available.
        padded_SP_normal = torch.nn.functional.normalize(padded_SP,
                                                         p=1,
                                                         dim=-1)
        ps_guided_SP_gold = torch.sum(torch.mul(padded_SP_normal.unsqueeze(-1),
                                                b2s_similarity_matrix),
                                      dim=1)
        mask_score_object1_gold = torch.nn.functional.normalize(padded_O1 +
                                                                0.01,
                                                                p=1,
                                                                dim=-1)
        mask_score_object2_gold = torch.nn.functional.normalize(padded_O2 +
                                                                0.01,
                                                                p=1,
                                                                dim=-1)
        ps_SP_object1_gold = torch.mul(mask_score_object1_gold,
                                       ps_guided_SP_gold)
        ps_SP_object1_gold = allenutil.masked_softmax(ps_SP_object1_gold,
                                                      padded_mask_HS,
                                                      memory_efficient=True,
                                                      dim=-1)
        ps_SP_object2_gold = torch.mul(mask_score_object2_gold,
                                       ps_guided_SP_gold)
        ps_SP_object2_gold = allenutil.masked_softmax(ps_SP_object2_gold,
                                                      padded_mask_HS,
                                                      memory_efficient=True,
                                                      dim=-1)

        #  ****************************Step 6 relevance/comparison check ****************************
        summed_HB_weighted_pb_SP = torch.matmul(
            pb_SP.unsqueeze(1), padded_HB)  # 1XMXMXD => [B,1,D]
        summed_HS_weighted_ps_SP_o1 = torch.matmul(
            ps_SP_object1.unsqueeze(1), padded_HS)  # 1XNXNXD => [B,1,D]
        summed_HS_weighted_ps_SP_o2 = torch.matmul(
            ps_SP_object2.unsqueeze(1), padded_HS)  # 1XNXNXD => [B,1,D]
        p_relevance_logits = self.rel_SPo1_SPo2(summed_HB_weighted_pb_SP,
                                                summed_HS_weighted_ps_SP_o1,
                                                summed_HS_weighted_ps_SP_o2)
        normal_p_relevance_logits = torch.nn.functional.normalize(
            p_relevance_logits, p=1)
        p_relevance = torch.softmax(normal_p_relevance_logits, dim=-1)

        # gold label, if available.
        padded_SP_o1_normal = torch.nn.functional.normalize(padded_SP_o1,
                                                            p=1,
                                                            dim=-1)
        padded_SP_o2_normal = torch.nn.functional.normalize(padded_SP_o2,
                                                            p=1,
                                                            dim=-1)
        GOLD_summed_HB_weighted_pb_SP = torch.matmul(
            padded_SP_normal.unsqueeze(1).type(dtype=torch.float),
            padded_HB)  # 1XMXMXD => [B,1,D]
        GOLD_summed_HS_weighted_ps_SP_o1 = torch.matmul(
            padded_SP_o1_normal.unsqueeze(1).type(dtype=torch.float),
            padded_HS)  # 1XNXNXD => [B,1,D]
        GOLD_summed_HS_weighted_ps_SP_o2 = torch.matmul(
            padded_SP_o2_normal.unsqueeze(1).type(dtype=torch.float),
            padded_HS)  # 1XNXNXD => [B,1,D]
        p_relevance_logits_gold = self.rel_SPo1_SPo2(
            GOLD_summed_HB_weighted_pb_SP, GOLD_summed_HS_weighted_ps_SP_o1,
            GOLD_summed_HS_weighted_ps_SP_o2)
        normal_p_relevance_logits_gold = torch.nn.functional.normalize(
            p_relevance_logits_gold, p=1)
        p_relevance_gold = torch.softmax(normal_p_relevance_logits_gold,
                                         dim=-1)

        #  ****************************Step 7 relation classification/polarity ****************************
        summed_HB_weighted_pb_TP = torch.matmul(pb_TP.unsqueeze(1), padded_HB)
        summed_HB_weighted_TP_SP = torch.cat(
            (summed_HB_weighted_pb_SP, summed_HB_weighted_pb_TP),
            dim=-1).squeeze(1)
        p_polarity_logits = self.pol_TP_SP(summed_HB_weighted_TP_SP)
        p_polarity = torch.softmax(p_polarity_logits, dim=-1)
        p_polarity_negative = p_polarity[:, 0]
        p_polarity_positive = p_polarity[:, 1]

        # gold label, if available.
        summed_HB_weighted_pb_TP_gold = torch.matmul(
            padded_TP_normal.unsqueeze(1), padded_HB)
        summed_HB_weighted_TP_SP_gold = torch.cat(
            (GOLD_summed_HB_weighted_pb_SP, summed_HB_weighted_pb_TP_gold),
            dim=-1).squeeze(1)
        p_polarity_logits_gold = self.pol_TP_SP(summed_HB_weighted_TP_SP_gold)
        p_polarity_gold = torch.softmax(p_polarity_logits_gold, dim=-1)
        p_polarity_negative_gold = p_polarity_gold[:, 0]
        p_polarity_positive_gold = p_polarity_gold[:, 1]

        # ****************************Step 8 Reasoning ****************************
        object1 = p_relevance[:, 0]
        object2 = p_relevance[:, 1]
        p_TP_object1 = p_polarity_positive * object1 + p_polarity_negative * object2
        p_TP_object2 = p_polarity_negative * object1 + p_polarity_positive * object2
        p_TP_objects = torch.stack((p_TP_object1, p_TP_object2),
                                   dim=1).squeeze()

        # gold label, if available.
        object1_gold = p_relevance_gold[:, 0]
        object2_gold = p_relevance_gold[:, 1]
        p_TP_object1_gold = p_polarity_positive_gold * object1_gold + p_polarity_negative_gold * object2_gold
        p_TP_object2_gold = p_polarity_negative_gold * object1_gold + p_polarity_positive_gold * object2_gold
        p_TP_objects_gold = torch.stack((p_TP_object1_gold, p_TP_object2_gold),
                                        dim=1).squeeze()

        try:
            assert torch.sum(padded_O1) != 0
            assert torch.sum(padded_O2) != 0
            assert torch.sum(padded_SP_o1) != 0
            assert torch.sum(padded_SP_o2) != 0
            assert torch.sum(padded_SP) != 0
            assert torch.sum(padded_TP) != 0
        except:
            pass
        loss_o1 = compute_loss(
            ps_object1, padded_O1,
            "find_object1") if object1_label is not None else 0.0
        loss_o2 = compute_loss(
            ps_object2, padded_O2,
            "find_object2") if object2_label is not None else 0.0
        loss_TP = compute_loss(pb_TP, padded_TP,
                               "find_TP") if TP_Back_label is not None else 0.0
        loss_SP = compute_loss(pb_SP, padded_SP,
                               "find_SP") if SP_Back_label is not None else 0.0
        loss_SP_o1 = compute_loss(
            ps_SP_object1, padded_SP_o1,
            "find_SP_object1") if SP_Object1_label is not None else 0.0
        loss_SP_o2 = compute_loss(
            ps_SP_object2, padded_SP_o2,
            "find_SP_object2") if SP_Object2_label is not None else 0.0
        loss_rel = compute_loss(
            normal_p_relevance_logits, sp_relevance,
            "relevance") if sp_relevance is not None else 0.0
        loss_pol = compute_loss(
            p_polarity_logits, sp_tp_polarity,
            "polarity") if sp_tp_polarity is not None else 0.0
        loss_on_TP = compute_loss(
            p_TP_objects, tp_relevance,
            "TP_relevance") if tp_relevance is not None else 0.0

        loss2_SP = compute_loss(
            pb_SP_gold, padded_SP,
            "find_SP") if SP_Back_label is not None else 0.0
        loss2_SP_o1 = compute_loss(
            ps_SP_object1_gold, padded_SP_o1,
            "find_SP_object1") if SP_Object1_label is not None else 0.0
        loss2_SP_o2 = compute_loss(
            ps_SP_object2_gold, padded_SP_o2,
            "find_SP_object2") if SP_Object2_label is not None else 0.0
        loss2_rel = compute_loss(
            normal_p_relevance_logits_gold, sp_relevance,
            "relevance") if sp_relevance is not None else 0.0
        loss2_pol = compute_loss(
            p_polarity_logits_gold, sp_tp_polarity,
            "polarity") if sp_tp_polarity is not None else 0.0
        loss2_on_TP = compute_loss(
            p_TP_objects, tp_relevance,
            "TP_relevance") if tp_relevance is not None else 0.0
        out = {
            "object1": (loss_o1).tolist(),
            "object2": (loss_o2).tolist(),
            "TP": (loss_TP).tolist(),
            "SP": (loss_SP).tolist(),
            "loss_SP_o1": (loss_SP_o1).tolist(),
            "loss_SP_o2": (loss_SP_o2).tolist(),
            "loss_rel": (loss_rel).tolist(),
            "loss_pol": (loss_pol).tolist(),
            "loss_on_TP": (loss_on_TP).tolist(),
        }

        # Loss function; feel free to tune the weights.
        loss = 0.05 * loss_o1 + 0.05 * loss_o2 + 0.05 * loss_SP + 0.05 * loss_TP + 0.05 * loss_SP_o1 + 0.05 * loss_SP_o2 + 0.2 * loss_pol + 0.2 * loss_rel + 0.3 * loss_on_TP

        # The following part returns the numbers needed to predict each module's intermediate output.
        object1_ind = find_max_window(ps_object1,
                                      padded_mask_HS,
                                      offset=len_q + 3)
        object2_ind = find_max_window(ps_object2,
                                      padded_mask_HS,
                                      offset=len_q + 3)
        TP_ind = find_max_window(pb_TP,
                                 padded_mask_HB,
                                 offset=bs_seperator_index + 1)
        try:
            SP_object1_ind = find_max_window(ps_SP_object1,
                                             padded_mask_HS,
                                             offset=len_q + 3)
            SP_object2_ind = find_max_window(ps_SP_object2,
                                             padded_mask_HS,
                                             offset=len_q + 3)
            SP_ind = find_max_window(pb_SP,
                                     padded_mask_HB,
                                     offset=bs_seperator_index + 1)
        except:
            SP_object1_ind = [1, 1]
            SP_object2_ind = [1, 1]
            SP_ind = [1, 1]
        predict = {
            "p_o1": ps_object1.tolist(),
            "p_o2": ps_object2.tolist(),
            "p_TP": pb_TP.tolist(),
            "p_SP": pb_SP.tolist(),
            "p_sp_o1": ps_SP_object1.tolist(),
            "p_sp_o2": ps_SP_object2.tolist(),
            "object1": object1_ind,
            "object2": object2_ind,
            "TP": TP_ind,
            "SP": SP_ind,
            "SP_o1": SP_object1_ind,
            "SP_o2": SP_object2_ind,
            "relevance": p_relevance,
            "polarity": p_polarity,
            "tp_relevance": p_TP_objects,
        }
        output = [loss, out, predict]
        return output
Example #21
    def forward(self,
                document,
                query=None,
                label=None,
                metadata=None) -> Dict[str, Any]:
        generator_dict = self._generator(document)
        mask = util.get_text_field_mask(document)
        assert "a" in generator_dict
        assert "b" in generator_dict

        a, b = generator_dict['a'], generator_dict['b']
        a = a.clamp(1e-6, 100.)  # extreme values could result in NaNs
        b = b.clamp(1e-6, 100.)  # extreme values could result in NaNs

        output_dict = {}

        sampler = HardKuma([a, b],
                           support=[
                               self.support[0].to(a.device),
                               self.support[1].to(b.device)
                           ])
        generator_dict['predicted_rationale'] = (sampler.mean() >
                                                 0.5).long() * mask

        if self.prediction_mode or not self.training:
            if self._rationale_extractor is None:
                sample_z = (sampler.mean() > 0.5).long() * mask
            else:
                prob_z = sampler.mean()
                sample_z = self._rationale_extractor.extract_rationale(
                    prob_z, metadata, as_one_hot=True)
                output_dict[
                    "rationale"] = self._rationale_extractor.extract_rationale(
                        prob_z, metadata, as_one_hot=False)
                sample_z = torch.Tensor(sample_z).to(prob_z.device).float()
        else:
            sample_z = sampler.sample()

        sample_z = sample_z * mask

        wordpiece_to_token = document['bert']['wordpiece-to-token']
        wtt0 = torch.where(wordpiece_to_token == -1,
                           torch.tensor([0]).to(wordpiece_to_token.device),
                           wordpiece_to_token)
        wordpiece_sample = util.batched_index_select(sample_z.unsqueeze(-1),
                                                     wtt0)
        wordpiece_sample[wordpiece_to_token.unsqueeze(-1) == -1] = 1.0

        def scale_embeddings(module, input, output):
            output = output * wordpiece_sample
            return output

        hook = self._encoder._embedding_layer.register_forward_hook(
            scale_embeddings)

        encoder_dict = self._encoder(
            document=document,
            query=query,
            label=label,
            metadata=metadata,
        )

        hook.remove()

        loss = 0.0

        if label is not None:
            assert "loss" in encoder_dict

            base_loss = F.cross_entropy(encoder_dict["logits"], label)  # (B,)

            lasso = ((1 - sampler.pdf(0.)) * mask).sum(1)
            lengths = mask.sum(1)

            sparsity_loss = lasso / (lengths + 1e-9) - self._desired_length
            sparsity_loss = sparsity_loss.mean()

            self._loss_tracks["_lasso_loss"](sparsity_loss.item())

            # # moving average of the constraint
            # self.sparsity_ma = self.lagrange_alpha * self.sparsity_ma + (1 - self.lagrange_alpha) * sparsity_loss.item()

            # # update lambda
            # self.lambda0 = self.lambda0 * torch.exp(self.lagrange_lr * self.sparsity_ma.detach())

            self._loss_tracks["_base_loss"](base_loss.item())
            # self._loss_tracks["_fused_lasso_loss"](self.lambda0.item())

            # loss += base_loss + min(max(self.lambda0.detach().item(), 0.01), 1.0) * sparsity_loss
            loss += base_loss + self._reg_loss_lambda * sparsity_loss

        output_dict["probs"] = encoder_dict["probs"]
        output_dict["predicted_labels"] = encoder_dict["predicted_labels"]

        output_dict["loss"] = loss
        output_dict["gold_labels"] = label
        output_dict["metadata"] = metadata

        output_dict["predicted_rationale"] = generator_dict[
            "predicted_rationale"]

        self._loss_tracks["_rat_length"](util.masked_mean(
            generator_dict["predicted_rationale"], mask, dim=-1).mean().item())

        self._call_metrics(output_dict)

        return output_dict
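The register_forward_hook trick above is the part that actually applies the sampled rationale: the hook multiplies the embedding layer's output by the per-wordpiece sample before the encoder runs. A minimal, standalone sketch of the same mechanism follows, with a toy nn.Embedding and a hand-made scale tensor standing in for the real embedding layer and sampled rationale (both are assumptions for illustration only).

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)                 # toy stand-in for the real embedding layer
token_ids = torch.tensor([[1, 2, 3]])           # (batch_size, seq_len)
scale = torch.tensor([[[1.0], [0.0], [1.0]]])   # (batch_size, seq_len, 1), e.g. a sampled rationale

def scale_embeddings(module, inputs, output):
    # Returning a tensor from a forward hook replaces the module's output,
    # so the downstream encoder only ever sees the scaled embeddings.
    return output * scale

hook = embedding.register_forward_hook(scale_embeddings)
scaled = embedding(token_ids)                   # the second token's embedding is zeroed out
hook.remove()
assert torch.all(scaled[0, 1] == 0)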
Example #23
File: slstm.py  Project: Shuailong/SPM
    def forward(
            self,  # type: ignore
            inputs: torch.FloatTensor,
            mask: torch.FloatTensor):
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``
            A tensor of shape (batch_size, seq_len, hidden_size)
        mask : ``torch.FloatTensor``
            A tensor of shape (batch_size, seq_len)
        Returns
        -------
        hidden : ``torch.FloatTensor``
            A tensor of shape (batch_size, seq_len, hidden_size) containing the
            final hidden state of every word node.
        """
        batch_size, _, hidden_size = inputs.size()

        # (batch_size, seq_len, 1): used to mask padded positions in the masked means and gates below
        mask = mask.unsqueeze(-1)

        ############################################################################
        # Init states
        ############################################################################
        # randomly initialize the states
        hidden = torch.rand_like(inputs) - 0.5
        cell = torch.rand_like(inputs) - 0.5

        global_hidden = masked_mean(hidden, mask, dim=1)
        global_cell = masked_mean(cell, mask, dim=1)

        for _ in range(self.num_layers):
            #############################
            # update global node states #
            #############################
            hidden_avg = masked_mean(hidden, mask, dim=1)

            projected_input = self.g_input_linearity(global_hidden)
            projected_hiddens = self.g_hidden_linearity(hidden)
            projected_avg = self.g_avg_linearity(hidden_avg)

            input_gate = torch.sigmoid(self.layer_norms[0](
                projected_input[:, 0 * hidden_size:1 * hidden_size] +
                projected_avg[:, 0 * hidden_size:1 * hidden_size]))
            hidden_gates = torch.sigmoid(self.layer_norms[1](
                projected_input[:, 1 * hidden_size:2 *
                                hidden_size].unsqueeze(1).expand_as(hidden) +
                projected_hiddens))
            output_gate = torch.sigmoid(self.layer_norms[2](
                projected_input[:, 2 * hidden_size:3 * hidden_size] +
                projected_avg[:, 1 * hidden_size:2 * hidden_size]))

            masked_hidden_gates = hidden_gates.masked_fill((1 - mask).bool(),
                                                           -1e32)
            all_gates = torch.cat(
                [input_gate.unsqueeze(1), masked_hidden_gates], dim=1)
            gates_normalized = torch.nn.functional.softmax(all_gates, dim=1)

            input_gate_normalized = gates_normalized[:, 0, :]
            hidden_gates_normalized = gates_normalized[:, 1:, :]

            # new global states
            global_cell = (hidden_gates_normalized * cell).sum(1) + \
                global_cell * input_gate_normalized
            global_hidden = output_gate * torch.tanh(global_cell)

            #############################
            # update hidden node states #
            #############################

            # Note: add <bos> and <eos> to the input beforehand, so that the first and
            # last valid words are not dropped when the states are shifted left/right
            # (a standalone sketch of this shifting step follows this example).
            hidden_l = torch.cat([
                hidden.new_zeros(batch_size, 1, hidden_size), hidden[:, :-1, :]
            ],
                                 dim=1)
            hidden_r = torch.cat([
                hidden[:, 1:, :],
                hidden.new_zeros(batch_size, 1, hidden_size)
            ],
                                 dim=1)
            cell_l = torch.cat(
                [cell.new_zeros(batch_size, 1, hidden_size), cell[:, :-1, :]],
                dim=1)
            cell_r = torch.cat(
                [cell[:, 1:, :],
                 cell.new_zeros(batch_size, 1, hidden_size)],
                dim=1)

            # concat with neighbors
            contexts = torch.cat([hidden_l, hidden_r], dim=-1)

            projected_contexts = self.h_context_linearity(contexts)
            projected_current = self.h_current_linearity(hidden)
            projected_input = self.h_input_linearity(inputs)
            projected_global = self.h_global_linearity(global_hidden)

            gates = []
            for offset in range(6):
                gates.append(
                    torch.sigmoid(
                        self.layer_norms[offset + 3]
                        (projected_contexts[..., offset * hidden_size:
                                            (offset + 1) * hidden_size] +
                         projected_current[..., offset * hidden_size:
                                           (offset + 1) * hidden_size] +
                         projected_input[..., offset * hidden_size:
                                         (offset + 1) * hidden_size] +
                         projected_global[
                             ..., offset * hidden_size:(offset + 1) *
                             hidden_size].unsqueeze(1).expand_as(inputs))))
            memory_init = torch.tanh(self.layer_norms[-1](
                projected_contexts[..., 6 * hidden_size:7 * hidden_size] +
                projected_current[..., 6 * hidden_size:7 * hidden_size] +
                projected_input[..., 6 * hidden_size:7 * hidden_size] +
                projected_global[..., 6 * hidden_size:7 *
                                 hidden_size].unsqueeze(1).expand_as(inputs)))

            # gate: batch x seq_len x hidden_size
            gates_normalized = F.softmax(torch.stack(gates[:-1]), dim=0)
            input_gate = gates_normalized[0, ...]
            left_gate = gates_normalized[1, ...]
            right_gate = gates_normalized[2, ...]
            forget_gate = gates_normalized[3, ...]
            global_gate = gates_normalized[4, ...]
            output_gate = gates[-1]

            cell = left_gate * cell_l +\
                right_gate * cell_r +\
                forget_gate * cell +\
                input_gate * memory_init +\
                global_gate * global_cell.unsqueeze(1).expand_as(global_gate)

            hidden = output_gate * torch.tanh(cell)
            hidden = hidden * mask
            cell = cell * mask

        return hidden
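A standalone sketch (toy shapes, not part of slstm.py) of the neighbour-shifting step used above to build hidden_l / hidden_r: pad one side with a zero step and slice off the other, so every position sees its left and right neighbours, with zeros at the sequence boundaries.

import torch

batch_size, seq_len, hidden_size = 2, 4, 3
hidden = torch.arange(batch_size * seq_len * hidden_size,
                      dtype=torch.float).view(batch_size, seq_len, hidden_size)

# left neighbour of position i is hidden[:, i - 1, :]; position 0 gets a zero vector
hidden_l = torch.cat(
    [hidden.new_zeros(batch_size, 1, hidden_size), hidden[:, :-1, :]], dim=1)
# right neighbour of position i is hidden[:, i + 1, :]; the last position gets zeros
hidden_r = torch.cat(
    [hidden[:, 1:, :], hidden.new_zeros(batch_size, 1, hidden_size)], dim=1)

assert hidden_l.shape == hidden_r.shape == hidden.shape
assert torch.all(hidden_l[:, 1:] == hidden[:, :-1])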
Example #24
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            store_metrics: bool = True,
            valid_output_mask: torch.LongTensor = None,
            sent_targets: torch.Tensor = None,
            stance: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        store_metrics : bool
            If true, stores metrics (if applicable) within model metric tracker.
            If false, returns resulting metrics immediately, without updating the model metric tracker.
        valid_output_mask: ``torch.LongTensor``, optional
            The locations for a valid answer. Used to limit the model's output space.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        # Debate: Conditioning on whose turn it is (A/B)
        if not self.is_judge:
            turn_film_params = self._turn_film_gen(
                stance.to(final_merged_passage).unsqueeze(1))
            turn_gammas, turn_betas = torch.split(
                turn_film_params, self._modeling_layer.get_input_dim(), dim=-1)
            final_merged_passage_mask = (
                final_merged_passage !=
                0).float()  # NOTE: Using heuristic to get mask
            final_merged_passage = self._film(
                final_merged_passage, 1. + turn_gammas,
                turn_betas) * final_merged_passage_mask
        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input_full = torch.cat(
            [final_merged_passage, modeled_passage], dim=-1)
        span_start_input = self._dropout(span_start_input_full)
        if not self.is_judge:
            value_head_input = span_start_input_full.detach(
            ) if self._detach_value_head else span_start_input_full
            # Shape: (batch_size, passage_length)
            tokenwise_values = self._value_head(value_head_input).squeeze(-1)
            value, value_loc = util.replace_masked_values(
                tokenwise_values, passage_mask, -1e7).max(-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        valid_output_mask = passage_mask if valid_output_mask is None else valid_output_mask
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits,
                                               valid_output_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits,
                                             valid_output_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       valid_output_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     valid_output_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention":
            passage_question_attention,
            "span_start_logits":
            span_start_logits,
            "span_start_probs":
            span_start_probs,
            "span_end_logits":
            span_end_logits,
            "span_end_probs":
            span_end_probs,
            "best_span":
            best_span,
            "value":
            value if not self.is_judge else None,
            "prob":
            torch.tensor([
                span_start_probs[i, span_start[i]]
                if span_start[i] < span_start_probs.size(1) else 0.
                for i in range(batch_size)
            ]) if self.is_judge else None,  # prob(true answer)
            "prob_dist":
            span_start_probs,
        }

        # Compute the loss for training.
        if (span_start is not None) and self.is_judge:
            span_start[span_start >= passage_mask.size(
                1)] = -100  # NB: Hacky. Don't add to loss if span not in input
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, valid_output_mask),
                span_start.squeeze(-1))
            if store_metrics:
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(-1))
            span_end[span_end >= passage_mask.size(
                1)] = -100  # NB: Hacky. Don't add to loss if span not in input
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, valid_output_mask),
                span_end.squeeze(-1))
            if store_metrics:
                self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
                self._span_accuracy(best_span,
                                    torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        elif not self.is_judge:  # Debate SL
            if self.reward_method == 'sl':  # sent_targets should be a vector of target indices
                output_dict["loss"] = nll_loss(
                    util.masked_log_softmax(span_start_logits,
                                            valid_output_mask),
                    sent_targets.squeeze(-1))
                if store_metrics:
                    self._span_start_accuracy(span_start_logits,
                                              sent_targets.squeeze(-1))
            elif self.reward_method.startswith('sl-sents'):
                # sent_targets should be a matrix of target values (non-zero only in EOS indices)
                sent_targets = util.replace_masked_values(
                    sent_targets, valid_output_mask, -1e7)
                output_dict["loss"] = util.masked_mean(
                    ((span_start_logits - sent_targets)**2), valid_output_mask,
                    1)
                if store_metrics:
                    self._span_start_accuracy(span_start_logits,
                                              sent_targets.max(-1)[1])

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        batch_ems = []
        batch_f1s = []
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    sample_squad_metrics = SquadEmAndF1()
                    sample_squad_metrics(best_span_string, answer_texts)
                    sample_em, sample_f1 = sample_squad_metrics.get_metric(
                        reset=True)
                    batch_ems.append(sample_em)
                    batch_f1s.append(sample_f1)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['em'] = torch.tensor(batch_ems)
            output_dict['f1'] = torch.tensor(batch_f1s)
        return output_dict
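self.get_best_span is not shown in this snippet; the sketch below only illustrates the kind of constrained inference the docstring describes (maximise start score plus end score subject to start <= end, in a single left-to-right pass), and is not the model's own implementation.

import torch

def best_span(span_start_logits: torch.Tensor,
              span_end_logits: torch.Tensor) -> torch.Tensor:
    # For each batch element, pick (start, end) with start <= end that maximises
    # span_start_logits[start] + span_end_logits[end].
    batch_size, passage_length = span_start_logits.size()
    best = torch.zeros(batch_size, 2, dtype=torch.long)
    for b in range(batch_size):
        best_score = float("-inf")
        best_start, best_start_score = 0, float("-inf")
        for end in range(passage_length):
            # track the best start position seen so far (i.e. one that is <= end)
            if span_start_logits[b, end].item() > best_start_score:
                best_start_score = span_start_logits[b, end].item()
                best_start = end
            score = best_start_score + span_end_logits[b, end].item()
            if score > best_score:
                best_score = score
                best[b, 0], best[b, 1] = best_start, end
    return best

starts = torch.tensor([[0.1, 2.0, 0.0, 0.0]])
ends = torch.tensor([[0.0, 0.0, 3.0, 1.0]])
assert best_span(starts, ends).tolist() == [[1, 2]]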
Example #25
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:

        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(
            max_batch_span_width,
            util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(
            raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor,
                                                    span_indices,
                                                    flat_span_indices)

        # Shape: (batch_size, num_spans, embedding_dim)
        # span_embeddings = util.masked_max(span_embeddings, span_mask.unsqueeze(-1), dim=2)
        span_embeddings = util.masked_mean(span_embeddings,
                                           span_mask.unsqueeze(-1),
                                           dim=2)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            span_width_embeddings = self._span_width_embedding(
                span_widths.squeeze(-1))
            span_embeddings = torch.cat(
                [span_embeddings, span_width_embeddings], -1)

        return span_embeddings
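The broadcasted comparison above is easy to misread; here is a standalone sketch of the same span-mask construction with toy indices, using torch.arange in place of util.get_range_vector.

import torch

span_starts = torch.tensor([[0, 2]])                    # (batch_size, num_spans)
span_ends = torch.tensor([[1, 4]])                      # inclusive end indices
span_widths = (span_ends - span_starts).unsqueeze(-1)   # (batch_size, num_spans, 1)
max_batch_span_width = span_widths.max().item() + 1

# (1, 1, max_batch_span_width)
max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)
# positions past a span's own (inclusive) width are masked out ...
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends.unsqueeze(-1) - max_span_range_indices
# ... as are positions that would fall before the start of the sequence
span_mask = span_mask * (raw_span_indices >= 0).float()
span_indices = torch.relu(raw_span_indices.float()).long()

assert span_mask.tolist() == [[[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]]
assert span_indices.tolist() == [[[1, 0, 0], [4, 3, 2]]]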
Example #26
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            choices_list: Dict[str, torch.LongTensor],
            choice_kb: Dict[str, torch.LongTensor],
            answer_text: Dict[str, torch.LongTensor],
            fact: Dict[str, torch.LongTensor],
            answer_spans: torch.IntTensor,
            relations: torch.IntTensor = None,
            relation_label: torch.IntTensor = None,
            answer_id: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # B X C X Ct X D
        embedded_choice, choice_mask = get_embedding(choices_list, 1,
                                                     self._text_field_embedder,
                                                     self._encoder,
                                                     self._var_dropout)
        # B X C X D
        # agg_choice, agg_choice_mask = get_agg_rep(embedded_choice, choice_mask, 1, self._encoder, self._aggregate)
        num_choices = embedded_choice.size()[1]
        batch_size = embedded_choice.size()[0]
        # B X Qt X D
        embedded_question, question_mask = get_embedding(
            question, 0, self._text_field_embedder, self._encoder,
            self._var_dropout)
        # B X D
        agg_question, agg_question_mask = get_agg_rep(embedded_question,
                                                      question_mask, 0,
                                                      self._encoder,
                                                      self._aggregate)

        # B X Ft X D
        embedded_fact, fact_mask = get_embedding(fact, 0,
                                                 self._text_field_embedder,
                                                 self._encoder,
                                                 self._var_dropout)
        # B X D
        agg_fact, agg_fact_mask = get_agg_rep(embedded_fact, fact_mask, 0,
                                              self._encoder, self._aggregate)

        # ==============================================
        # Interaction between fact and question
        # ==============================================
        # B x Ft x Qt
        fact_question_att = self._attention(embedded_fact, embedded_question)
        fact_question_mask = self.add_dimension(question_mask, 1,
                                                fact_question_att.shape[1])
        masked_fact_question_att = replace_masked_values(
            fact_question_att, fact_question_mask, -1e7)
        # B X Ft
        fact_question_att_max = masked_fact_question_att.max(
            dim=-1)[0].squeeze(-1)
        fact_question_att_softmax = masked_softmax(fact_question_att_max,
                                                   fact_mask)
        # B X D
        fact_question_att_rep = weighted_sum(embedded_fact,
                                             fact_question_att_softmax)
        # B*C X D
        cmerged_fact_question_att_rep = self.merge_dimensions(
            self.add_dimension(fact_question_att_rep, 1, num_choices))

        # ==============================================
        # Interaction between fact and answer choices
        # ==============================================

        # B*C X Ft X D
        cmerged_embedded_fact = self.merge_dimensions(
            self.add_dimension(embedded_fact, 1, num_choices))
        cmerged_fact_mask = self.merge_dimensions(
            self.add_dimension(fact_mask, 1, num_choices))

        # B*C X Ct X D
        cmerged_embedded_choice = self.merge_dimensions(embedded_choice)
        cmerged_choice_mask = self.merge_dimensions(choice_mask)

        # B*C X Ft X Ct
        cmerged_fact_choice_att = self._attention(cmerged_embedded_fact,
                                                  cmerged_embedded_choice)
        cmerged_fact_choice_mask = self.add_dimension(
            cmerged_choice_mask, 1, cmerged_fact_choice_att.shape[1])
        masked_cmerged_fact_choice_att = replace_masked_values(
            cmerged_fact_choice_att, cmerged_fact_choice_mask, -1e7)

        # B*C X Ft
        cmerged_fact_choice_att_max = masked_cmerged_fact_choice_att.max(
            dim=-1)[0].squeeze(-1)
        cmerged_fact_choice_att_softmax = masked_softmax(
            cmerged_fact_choice_att_max, cmerged_fact_mask)

        # B*C X D
        cmerged_fact_choice_att_rep = weighted_sum(
            cmerged_embedded_fact, cmerged_fact_choice_att_softmax)

        # ==============================================
        # Combined fact + choice + question + span rep
        # ==============================================
        if not self._ignore_spans and not self._ignore_ann:
            # B X A
            per_span_mask = (answer_spans >= 0).long()[:, :, 0]
            # B X A X D
            per_span_rep = self._span_extractor(embedded_fact, answer_spans,
                                                fact_mask, per_span_mask)
            # expanded_span_mask = per_span_mask.unsqueeze(-1).expand_as(per_span_rep)

            # B X D
            answer_span_rep = per_span_rep[:, 0, :]

            # B*C X D
            cmerged_span_rep = self.merge_dimensions(
                self.add_dimension(answer_span_rep, 1, num_choices))
            fact_choice_question_rep = (cmerged_fact_choice_att_rep +
                                        cmerged_fact_question_att_rep +
                                        cmerged_span_rep) / 3

        else:
            fact_choice_question_rep = (cmerged_fact_choice_att_rep +
                                        cmerged_fact_question_att_rep) / 2
        # B*C X D
        cmerged_fact_rep = masked_mean(
            cmerged_embedded_fact,
            cmerged_fact_mask.unsqueeze(-1).expand_as(cmerged_embedded_fact),
            1)
        # B*C X D
        fact_question_combined_rep = combine_tensors(
            self._coverage_combination,
            [fact_choice_question_rep, cmerged_fact_rep])

        # B X C X  D
        new_size = [batch_size, num_choices, -1]
        fact_question_combined_rep = fact_question_combined_rep.contiguous(
        ).view(*new_size)
        # B X C
        coverage_score = self._coverage_ff(fact_question_combined_rep).squeeze(
            -1)
        logger.info("coverage_score" + str(coverage_score.shape))

        # ==============================================
        # Interaction between spans+choices and KB
        # ==============================================

        # B X C X K X Kt x D
        embedded_choice_kb, choice_kb_mask = get_embedding(
            choice_kb, 2, self._text_field_embedder, self._encoder,
            self._var_dropout)
        num_kb = embedded_choice_kb.size()[2]

        # B X A X At X D
        embedded_answer, answer_mask = get_embedding(answer_text, 1,
                                                     self._text_field_embedder,
                                                     self._encoder,
                                                     self._var_dropout)
        # B X At X D
        embedded_answer = embedded_answer[:, 0, :, :]
        answer_mask = answer_mask[:, 0, :]

        # B*C*K X Kt X D
        ckmerged_embedded_choice_kb = self.merge_dimensions(
            self.merge_dimensions(embedded_choice_kb))
        ckmerged_choice_kb_mask = self.merge_dimensions(
            self.merge_dimensions(choice_kb_mask))

        # B*C X At X D
        cmerged_embedded_answer = self.merge_dimensions(
            self.add_dimension(embedded_answer, 1, num_choices))
        cmerged_answer_mask = self.merge_dimensions(
            self.add_dimension(answer_mask, 1, num_choices))
        # B*C*K X At X D
        ckmerged_embedded_answer = self.merge_dimensions(
            self.add_dimension(cmerged_embedded_answer, 1, num_kb))
        ckmerged_answer_mask = self.merge_dimensions(
            self.add_dimension(cmerged_answer_mask, 1, num_kb))
        # B*C*K X Ct X D
        ckmerged_embedded_choice = self.merge_dimensions(
            self.add_dimension(cmerged_embedded_choice, 1, num_kb))
        ckmerged_choice_mask = self.merge_dimensions(
            self.add_dimension(cmerged_choice_mask, 1, num_kb))
        logger.info("ckmerged_choice_mask" + str(ckmerged_choice_mask.shape))

        # == KB rep based on answer span ==
        if self._ignore_ann:
            # B*C*K X Ft X D
            ckmerged_embedded_fact = self.merge_dimensions(
                self.add_dimension(cmerged_embedded_fact, 1, num_kb))
            ckmerged_fact_mask = self.merge_dimensions(
                self.add_dimension(cmerged_fact_mask, 1, num_kb))
            # B*C*K X Kt x Ft
            ckmerged_kb_fact_att = self._attention(ckmerged_embedded_choice_kb,
                                                   ckmerged_embedded_fact)
            ckmerged_kb_fact_mask = self.add_dimension(
                ckmerged_fact_mask, 1, ckmerged_kb_fact_att.shape[1])
            masked_ckmerged_kb_fact_att = replace_masked_values(
                ckmerged_kb_fact_att, ckmerged_kb_fact_mask, -1e7)

            # B*C*K X Kt
            ckmerged_kb_answer_att_max = masked_ckmerged_kb_fact_att.max(
                dim=-1)[0].squeeze(-1)
        else:
            # B*C*K X Kt x At
            ckmerged_kb_answer_att = self._attention(
                ckmerged_embedded_choice_kb, ckmerged_embedded_answer)
            ckmerged_kb_answer_mask = self.add_dimension(
                ckmerged_answer_mask, 1, ckmerged_kb_answer_att.shape[1])
            masked_ckmerged_kb_answer_att = replace_masked_values(
                ckmerged_kb_answer_att, ckmerged_kb_answer_mask, -1e7)

            # B*C*K X Kt
            ckmerged_kb_answer_att_max = masked_ckmerged_kb_answer_att.max(
                dim=-1)[0].squeeze(-1)

        ckmerged_kb_answer_att_softmax = masked_softmax(
            ckmerged_kb_answer_att_max, ckmerged_choice_kb_mask)

        # B*C*K X D
        kb_answer_att_rep = weighted_sum(ckmerged_embedded_choice_kb,
                                         ckmerged_kb_answer_att_softmax)

        # == KB rep based on answer choice ==
        # B*C*K X Kt x Ct
        ckmerged_kb_choice_att = self._attention(ckmerged_embedded_choice_kb,
                                                 ckmerged_embedded_choice)
        ckmerged_kb_choice_mask = self.add_dimension(
            ckmerged_choice_mask, 1, ckmerged_kb_choice_att.shape[1])
        masked_ckmerged_kb_choice_att = replace_masked_values(
            ckmerged_kb_choice_att, ckmerged_kb_choice_mask, -1e7)

        # B*C*K X Kt
        ckmerged_kb_choice_att_max = masked_ckmerged_kb_choice_att.max(
            dim=-1)[0].squeeze(-1)
        ckmerged_kb_choice_att_softmax = masked_softmax(
            ckmerged_kb_choice_att_max, ckmerged_choice_kb_mask)

        # B*C*K X D
        kb_choice_att_rep = weighted_sum(ckmerged_embedded_choice_kb,
                                         ckmerged_kb_choice_att_softmax)

        # B*C*K X D
        answer_choice_kb_combined_rep = combine_tensors(
            self._answer_choice_combination,
            [kb_answer_att_rep, kb_choice_att_rep])
        logger.info("answer_choice_kb_combined_rep" +
                    str(answer_choice_kb_combined_rep.shape))

        # ==============================================
        # Relation Predictions
        # ==============================================

        # B*C*K x R
        choice_kb_relation_rep = self._relation_predictor(
            answer_choice_kb_combined_rep)
        new_choice_kb_size = [batch_size * num_choices, num_kb, -1]
        # B*C*K
        merged_choice_kb_mask = (torch.sum(ckmerged_choice_kb_mask, dim=-1) >
                                 0).float()
        if self._num_relations and not self._ignore_ann:
            if self._relation_projector:
                choice_kb_relation_pred = self._relation_projector(
                    choice_kb_relation_rep)
            else:
                choice_kb_relation_pred = choice_kb_relation_rep

            # Aggregate the predictions
            # B*C*K
            choice_kb_relation_mask = self.add_dimension(
                merged_choice_kb_mask, -1, choice_kb_relation_pred.shape[-1])
            choice_kb_relation_pred_masked = replace_masked_values(
                choice_kb_relation_pred, choice_kb_relation_mask, -1e7)
            # B*C X K X R
            relation_pred_perkb = choice_kb_relation_pred_masked.contiguous(
            ).view(*new_choice_kb_size)
            # B*C X R
            relation_pred_max = relation_pred_perkb.max(dim=1)[0].squeeze(1)

            # B X C X R
            choice_relation_size = [batch_size, num_choices, -1]
            relation_label_logits = relation_pred_max.contiguous().view(
                *choice_relation_size)
            relation_label_probs = softmax(relation_label_logits, dim=-1)
            # B X C
            add_relation_predictions(self.vocab, relation_label_probs,
                                     metadata)
            # B X C X K X R
            choice_kb_relation_size = [batch_size, num_choices, num_kb, -1]
            relation_predictions = choice_kb_relation_rep.contiguous().view(
                *choice_kb_relation_size)
            add_tuple_predictions(relation_predictions, metadata)
            logger.info("relation_predictions" +
                        str(relation_predictions.shape))
        else:
            relation_label_logits = None
            relation_label_probs = None

        if not self._ignore_relns:
            # B X C X D
            expanded_size = [batch_size, num_choices, -1]
            # Aggregate the relation representation
            if self._relation_projector or self._num_relations == 0 or self._ignore_ann:
                # B*C X K X D
                relation_rep_perkb = choice_kb_relation_rep.contiguous().view(
                    *new_choice_kb_size)
                # B*C*K X D
                merged_relation_rep_mask = self.add_dimension(
                    merged_choice_kb_mask, -1, relation_rep_perkb.shape[-1])
                # B*C X K X D
                relation_rep_perkb_mask = merged_relation_rep_mask.contiguous(
                ).view(*relation_rep_perkb.size())
                # B*C X D
                agg_relation_rep = masked_mean(relation_rep_perkb,
                                               relation_rep_perkb_mask,
                                               dim=1)
                # B X C X D
                expanded_relation_rep = agg_relation_rep.contiguous().view(
                    *expanded_size)
            else:
                expanded_relation_rep = relation_label_logits

            expanded_question_rep = agg_question.unsqueeze(1).expand(
                expanded_size)
            expanded_fact_rep = agg_fact.unsqueeze(1).expand(expanded_size)
            question_fact_rep = combine_tensors(
                self._combination, [expanded_question_rep, expanded_fact_rep])

            relation_score_rep = torch.cat(
                [question_fact_rep, expanded_relation_rep], dim=-1)
            relation_score = self._reln_ff(relation_score_rep).squeeze(-1)
            choice_label_logits = (coverage_score + relation_score) / 2
        else:
            choice_label_logits = coverage_score
        logger.info("choice_label_logits" + str(choice_label_logits.shape))

        choice_label_probs = softmax(choice_label_logits, dim=-1)
        output_dict = {
            "label_logits": choice_label_logits,
            "label_probs": choice_label_probs,
            "metadata": metadata
        }
        if relation_label_logits is not None:
            output_dict["relation_label_logits"] = relation_label_logits
            output_dict["relation_label_probs"] = relation_label_probs

        if answer_id is not None or relation_label is not None:
            self.compute_loss_and_accuracy(answer_id, relation_label,
                                           relation_label_logits,
                                           choice_label_logits, output_dict)
        return output_dict
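add_dimension and merge_dimensions are helpers defined elsewhere in this model; the sketch below shows what they plausibly do (hypothetical implementations for illustration, not the project's actual code): expand a tensor along a new choice axis, then flatten batch and choice into one leading dimension.

import torch

def add_dimension(tensor: torch.Tensor, dim: int, size: int) -> torch.Tensor:
    # hypothetical helper: insert a new axis at `dim` and repeat it `size` times
    dim = dim % (tensor.dim() + 1)
    return tensor.unsqueeze(dim).expand(
        *tensor.shape[:dim], size, *tensor.shape[dim:])

def merge_dimensions(tensor: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: collapse the first two axes into one
    return tensor.reshape(-1, *tensor.shape[2:])

embedded_fact = torch.randn(2, 7, 16)      # (B, Ft, D)
num_choices = 4
# (B, Ft, D) -> (B, C, Ft, D) -> (B*C, Ft, D), pairing the fact with every choice
cmerged_embedded_fact = merge_dimensions(add_dimension(embedded_fact, 1, num_choices))
assert cmerged_embedded_fact.shape == (2 * 4, 7, 16)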
Example #27
    def forward(
        self,
        context_1: torch.Tensor,
        mask_1: torch.Tensor,
        context_2: torch.Tensor,
        mask_2: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Given the forward (or backward) representations of sentence1 and sentence2, apply four bilateral
        matching functions between them in one direction.

        Parameters
        ----------
        context_1 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len1, hidden_dim) representing the encoding of the first sentence.
        mask_1 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len1), indicating which
            positions in the first sentence are padding (0) and which are not (1).
        context_2 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len2, hidden_dim) representing the encoding of the second sentence.
        mask_2 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len2), indicating which
            positions in the second sentence are padding (0) and which are not (1).

        Returns
        -------
        A tuple of matching vectors for the two sentences, each of which is a list of
        matching vectors of shape (batch, seq_len, num_perspectives or 1).
        """
        assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
        assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim

        # (batch,)
        len_1 = get_lengths_from_binary_sequence_mask(mask_1)
        len_2 = get_lengths_from_binary_sequence_mask(mask_2)

        # (batch, seq_len*)
        mask_1, mask_2 = mask_1.float(), mask_2.float()

        # explicitly set masked weights to zero
        # (batch_size, seq_len*, hidden_dim)
        context_1 = context_1 * mask_1.unsqueeze(-1)
        context_2 = context_2 * mask_2.unsqueeze(-1)

        # array to keep the matching vectors for the two sentences
        matching_vector_1: List[torch.Tensor] = []
        matching_vector_2: List[torch.Tensor] = []

        # Step 0. unweighted cosine
        # First calculate the cosine similarities between each forward
        # (or backward) contextual embedding and every forward (or backward)
        # contextual embedding of the other sentence.

        # (batch, seq_len1, seq_len2)
        cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2),
                                         context_2.unsqueeze(-3),
                                         dim=3)

        # (batch, seq_len*, 1)
        cosine_max_1 = masked_max(cosine_sim,
                                  mask_2.unsqueeze(-2),
                                  dim=2,
                                  keepdim=True)
        cosine_mean_1 = masked_mean(cosine_sim,
                                    mask_2.unsqueeze(-2),
                                    dim=2,
                                    keepdim=True)
        cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1),
                                  mask_1.unsqueeze(-2),
                                  dim=2,
                                  keepdim=True)
        cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1),
                                    mask_1.unsqueeze(-2),
                                    dim=2,
                                    keepdim=True)

        matching_vector_1.extend([cosine_max_1, cosine_mean_1])
        matching_vector_2.extend([cosine_max_2, cosine_mean_2])

        # Step 1. Full-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with the last time step of the forward (or backward)
        # contextual embedding of the other sentence
        if self.with_full_match:

            # (batch, 1, hidden_dim)
            if self.is_forward:
                # (batch, 1, hidden_dim)
                last_position_1 = (len_1 - 1).clamp(min=0)
                last_position_1 = last_position_1.view(-1, 1, 1).expand(
                    -1, 1, self.hidden_dim)
                last_position_2 = (len_2 - 1).clamp(min=0)
                last_position_2 = last_position_2.view(-1, 1, 1).expand(
                    -1, 1, self.hidden_dim)

                context_1_last = context_1.gather(1, last_position_1)
                context_2_last = context_2.gather(1, last_position_2)
            else:
                context_1_last = context_1[:, 0:1, :]
                context_2_last = context_2[:, 0:1, :]

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_full = multi_perspective_match(
                context_1, context_2_last, self.full_match_weights)
            matching_vector_2_full = multi_perspective_match(
                context_2, context_1_last, self.full_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_full)
            matching_vector_2.extend(matching_vector_2_full)

        # Step 2. Maxpooling-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with every time step of the forward (or backward)
        # contextual embedding of the other sentence, and only the max value of each
        # dimension is retained.
        if self.with_maxpool_match:
            # (batch, seq_len1, seq_len2, num_perspectives)
            matching_vector_max = multi_perspective_match_pairwise(
                context_1, context_2, self.maxpool_match_weights)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_max = masked_max(
                matching_vector_max, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            matching_vector_1_mean = masked_mean(
                matching_vector_max, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            matching_vector_2_max = masked_max(
                matching_vector_max.permute(0, 2, 1, 3),
                mask_1.unsqueeze(-2).unsqueeze(-1),
                dim=2)
            matching_vector_2_mean = masked_mean(
                matching_vector_max.permute(0, 2, 1, 3),
                mask_1.unsqueeze(-2).unsqueeze(-1),
                dim=2)

            matching_vector_1.extend(
                [matching_vector_1_max, matching_vector_1_mean])
            matching_vector_2.extend(
                [matching_vector_2_max, matching_vector_2_mean])

        # Step 3. Attentive-Matching
        # Each forward (or backward) similarity is taken as the weight
        # of the forward (or backward) contextual embedding, and calculate an
        # attentive vector for the sentence by weighted summing all its
        # contextual embeddings.
        # Finally match each forward (or backward) contextual embedding
        # with its corresponding attentive vector.

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_2 = context_2.unsqueeze(-3) * cosine_sim.unsqueeze(-1)

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_1 = context_1.unsqueeze(-2) * cosine_sim.unsqueeze(-1)

        if self.with_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_mean_2 = masked_softmax(att_2.sum(dim=2), mask_1.unsqueeze(-1))
            att_mean_1 = masked_softmax(att_1.sum(dim=1), mask_2.unsqueeze(-1))

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_mean = multi_perspective_match(
                context_1, att_mean_2, self.attentive_match_weights)
            matching_vector_2_att_mean = multi_perspective_match(
                context_2, att_mean_1, self.attentive_match_weights_reversed)
            matching_vector_1.extend(matching_vector_1_att_mean)
            matching_vector_2.extend(matching_vector_2_att_mean)

        # Step 4. Max-Attentive-Matching
        # Pick the contextual embeddings with the highest cosine similarity as the attentive
        # vector, and match each forward (or backward) contextual embedding with its
        # corresponding attentive vector.
        if self.with_max_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_max_2 = masked_max(att_2,
                                   mask_2.unsqueeze(-2).unsqueeze(-1),
                                   dim=2)
            att_max_1 = masked_max(att_1.permute(0, 2, 1, 3),
                                   mask_1.unsqueeze(-2).unsqueeze(-1),
                                   dim=2)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_max = multi_perspective_match(
                context_1, att_max_2, self.max_attentive_match_weights)
            matching_vector_2_att_max = multi_perspective_match(
                context_2, att_max_1,
                self.max_attentive_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_att_max)
            matching_vector_2.extend(matching_vector_2_att_max)

        return matching_vector_1, matching_vector_2
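multi_perspective_match and its pairwise variant are imported from elsewhere; the sketch below illustrates the per-perspective cosine they are generally understood to compute (an assumption for illustration, not the exact library code): each row of a (num_perspectives, hidden_dim) weight re-weights the hidden dimensions before a cosine similarity is taken.

import torch
import torch.nn.functional as F

batch, seq_len, hidden_dim, num_perspectives = 2, 5, 8, 4
context = torch.randn(batch, seq_len, hidden_dim)
other_last = torch.randn(batch, 1, hidden_dim)          # e.g. the other sentence's last time step
weights = torch.randn(num_perspectives, hidden_dim)     # one re-weighting per perspective

# (batch, seq_len, num_perspectives, hidden_dim)
context_p = weights * context.unsqueeze(2)
other_p = (weights * other_last.unsqueeze(2)).expand_as(context_p)
# (batch, seq_len, num_perspectives): one cosine score per position and perspective
matching = F.cosine_similarity(context_p, other_p, dim=3)
assert matching.shape == (batch, seq_len, num_perspectives)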
Example #28
    def forward(
        self,
        messages: Dict[str, torch.Tensor],
        # (batch_size, n_turns, n_facts, n_words)
        facts: Dict[str, torch.Tensor],
        # (batch_size, n_turns)
        senders: torch.Tensor,
        # (batch_size, n_turns, n_acts)
        dialog_acts: torch.Tensor,
        # (batch_size, n_turns)
        dialog_acts_mask: torch.Tensor,
        # (batch_size, n_entities)
        known_entities: Dict[str, torch.Tensor],
        # (batch_size, 1)
        focus_entity: Dict[str, torch.Tensor],
        # (batch_size, n_turns, n_facts)
        fact_labels: Optional[torch.Tensor] = None,
        # (batch_size, n_turns, 2)
        likes: Optional[torch.Tensor] = None,
        metadata: Optional[Dict] = None,
    ):
        output = {}
        # Take care of the easy stuff first

        # (batch_size, n_entities)
        known_entities_mask = get_text_field_mask(known_entities)

        # (batch_size, n_turns, sender_emb_size)
        sender_emb = self._sender_emb(senders)

        known_emb = self._mention_embedder(known_entities)
        # TODO: instead of being averaged, the known-entity embeddings could be attended over
        known_vec = self._known_net(
            masked_mean(known_emb, known_entities_mask.unsqueeze(-1), dim=1))
        # There is always exactly one entity
        focus_emb = self._focus_net(
            self._mention_embedder(focus_entity)[:, 0, :])

        if self._use_bert:
            # (batch_size, n_turns, n_words, emb_dim)
            context, utter_mask = self._bert_encoder(messages)
            context = self._dropout(context)
        else:
            # (batch_size, n_turns)
            # This is the mask since not all dialogs have same number
            # of turns
            utter_mask = get_text_field_mask(messages)

            # (batch_size, n_turns, n_words)
            # Mask since not all utterances have same number of words
            # Wrapping dim skips over n_messages dim
            text_mask = get_text_field_mask(messages, num_wrapping_dims=1)
            # (batch_size, n_turns, n_words, emb_dim)
            embed = self._dropout(self._utter_embedder(messages))
            # (batch_size, n_turns, hidden_dim)
            context = self._dist_utter_context(embed, text_mask)

        # (batch_size, n_turns, act_emb_size)
        act_emb = self._act_embedder(dialog_acts.float())
        act_emb = self._clamp_dialog_acts(act_emb)

        # (batch_size, n_turns, hidden_dim + known_dim + focus_dim + sender_dim + act_dim)
        n_turns = context.shape[1]
        full_context = torch.cat(
            (
                context,
                sender_emb,
                act_emb,
                focus_emb[:, None, :].repeat_interleave(n_turns, 1),
                known_vec[:, None, :].repeat_interleave(n_turns, 1),
            ),
            dim=-1,
        )

        # (batch_size, n_turns, hidden_dim)
        # This assumes dialog_context does not peek into future
        dialog_context = self._dialog_context(full_context, utter_mask)

        # shift context one right, pad with zeros at front
        # This makes it so that utter_t is paired with context_t-1
        # which is what we want
        # This is useful in a few different places, so compute it here once
        shape = dialog_context.shape
        shifted_context = torch.cat(
            (
                dialog_context.new_zeros([shape[0], 1, shape[2]]),
                dialog_context[:, :-1, :],
            ),
            dim=1,
        )
        has_loss = False

        if self._disable_dialog_acts:
            da_loss = 0
            policy_loss = 0
        else:
            # Dialog act per utter loss
            has_loss = True
            da_loss = self._compute_da_loss(
                output,
                context,
                shifted_context,
                utter_mask,
                dialog_acts,
                dialog_acts_mask,
            )
            # Policy loss
            policy_loss = self._compute_policy_loss(output, shifted_context,
                                                    utter_mask, dialog_acts,
                                                    dialog_acts_mask)

        if self._disable_facts:
            # If facts are disabled, don't output anything related
            # to them
            fact_loss = 0
        else:
            if self._use_bert:
                # (batch_size, n_turns, n_words, emb_dim)
                fact_repr, fact_mask = self._bert_encoder(facts)
                fact_repr = self._dropout(fact_repr)
                # Mask out user turns (even positions), mirroring the non-BERT branch below
                fact_mask[:, ::2] = 0
            else:
                # (batch_size, n_turns, n_facts)
                # Wrapping dim skips over n_messages
                fact_mask = get_text_field_mask(facts, num_wrapping_dims=1)
                # In addition to masking padded facts, also explicitly mask
                # user turns just in case
                fact_mask[:, ::2] = 0

                # (batch_size, n_turns, n_facts, n_words)
                # Wrapping dim skips over n_turns and n_facts
                fact_text_mask = get_text_field_mask(facts,
                                                     num_wrapping_dims=2)
                # (batch_size, n_turns, n_facts, n_words, emb_dim)
                # Share encoder with utter encoder
                # Fold the turn and fact dims together so the shared encoder can be reused
                fact_embed = self._dropout(self._utter_embedder(facts))
                shape = fact_embed.shape
                word_dim = shape[-2]
                emb_dim = shape[-1]
                reshaped_facts = fact_embed.view(-1, word_dim, emb_dim)
                reshaped_fact_text_mask = fact_text_mask.view(-1, word_dim)
                reshaped_fact_repr = self._utter_context(
                    reshaped_facts, reshaped_fact_text_mask)
                # No more emb dimension or word/seq dim
                fact_repr = reshaped_fact_repr.view(shape[:-2] + (-1, ))

            fact_logits = self._fact_ranker(
                shifted_context,
                fact_repr,
            )
            output["fact_logits"] = fact_logits
            if fact_labels is not None:
                has_loss = True
                fact_loss = self._compute_fact_loss(fact_logits, fact_labels,
                                                    fact_mask)
                self._fact_loss_metric(fact_loss.item())
                self._fact_mrr(fact_logits, fact_labels, mask=fact_mask)
            else:
                fact_loss = 0

        if self._disable_likes:
            like_loss = 0
        else:
            # (batch_size, n_turns, 2)
            like_logits = self._like_classifier(dialog_context)
            output["like_logits"] = like_logits

            # There are several masks here to get the loss/metrics correct
            # - utter_mask: mask out positions that do not have an utterance
            # - user_mask: mask out positions that hold a user utterance,
            #              since user turns are never liked
            # Using new_ones() preserves the type of the tensor
            user_mask = utter_mask.new_ones(utter_mask.shape)

            # Since user turns always fall on even indices, this masks them out
            user_mask[:, ::2] = 0
            final_mask = utter_mask * user_mask
            if likes is not None:
                has_loss = True
                masked_likes = likes * final_mask
                like_loss = sequence_cross_entropy_with_logits(
                    like_logits, masked_likes, final_mask)
                self._like_accuracy(like_logits, masked_likes, final_mask)
                self._like_loss_metric(like_loss.item())
            else:
                like_loss = 0

        if has_loss:
            output["loss"] = (self._fact_loss_weight * fact_loss + like_loss +
                              da_loss + policy_loss)

        return output
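Note: the two masking tricks in this forward pass can be reproduced in isolation. A minimal sketch, assuming toy tensor sizes and plain PyTorch only (none of the names below come from the model itself), of pairing utterance t with the context up to t-1 and of dropping user turns before the like loss:

import torch

# Hypothetical sizes, for illustration only.
batch_size, n_turns, hidden_dim = 2, 5, 4
dialog_context = torch.randn(batch_size, n_turns, hidden_dim)
utter_mask = torch.ones(batch_size, n_turns)

# Shift the context one turn to the right, padding the front with zeros,
# so that position t holds the context built from turns 0..t-1.
shifted_context = torch.cat(
    (dialog_context.new_zeros(batch_size, 1, hidden_dim),
     dialog_context[:, :-1, :]),
    dim=1,
)
assert torch.equal(shifted_context[:, 1:, :], dialog_context[:, :-1, :])

# Zero out user turns (even indices), then combine with the utterance
# mask so padded turns are excluded as well.
user_mask = utter_mask.new_ones(utter_mask.shape)
user_mask[:, ::2] = 0
final_mask = utter_mask * user_mask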
    def forward(self,
                document,
                kept_tokens,
                rationale=None,
                label=None,
                metadata=None) -> Dict[str, Any]:
        generator_dict = self._generator(document, rationale)
        mask = util.get_text_field_mask(document)
        assert "prob_z" in generator_dict

        prob_z = generator_dict["prob_z"]
        assert len(prob_z.shape) == 2

        prob_z = kept_tokens.float() + prob_z * (1 - kept_tokens)
        sampler = D.bernoulli.Bernoulli(probs=prob_z)

        sample_z = sampler.sample() * mask.float()
        encoder_dict = self._encoder(sample_z=sample_z,
                                     label=label,
                                     metadata=metadata)

        loss = 0.0

        if label is not None:
            assert "loss" in encoder_dict

            loss_sample = encoder_dict["loss"]  # (B,)
            loss += loss_sample.mean()

            lasso_loss = util.masked_mean(sample_z, mask, dim=-1)  # (B,)

            masked_sum = mask[:, :-1].sum(-1).clamp(1e-5)
            diff = (sample_z[:, 1:] - sample_z[:, :-1]).abs()
            masked_diff = (diff * mask[:, :-1]).sum(-1)
            fused_lasso_loss = masked_diff / masked_sum

            self._loss_tracks["lasso_loss"](lasso_loss.mean().item())
            self._loss_tracks["fused_lasso_loss"](
                fused_lasso_loss.mean().item())
            self._loss_tracks["base_loss"](loss_sample.mean().item())

            log_prob_z = torch.log(
                1 + torch.exp(sampler.log_prob(sample_z)))  # (B, L)
            log_prob_z_sum = (mask * log_prob_z).mean(-1)  # (B,)

            generator_loss = (
                loss_sample.detach() + lasso_loss * self._reg_loss_lambda +
                fused_lasso_loss *
                (self._reg_loss_mu * self._reg_loss_lambda)) * log_prob_z_sum

            loss += self._reinforce_loss_weight * generator_loss.mean()

        output_dict = generator_dict

        loss += self._rationale_supervision_loss_weight * generator_dict.get(
            "rationale_supervision_loss", 0.0)

        output_dict["logits"] = encoder_dict["logits"]
        output_dict["probs"] = encoder_dict["probs"]
        output_dict["class_probs"] = encoder_dict["class_probs"]
        output_dict["predicted_labels"] = encoder_dict["predicted_labels"]
        output_dict["gold_labels"] = encoder_dict["gold_labels"]

        output_dict["loss"] = loss
        output_dict["metadata"] = metadata
        output_dict["mask"] = mask

        self._call_metrics(output_dict)

        return output_dict
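For reference, the lasso and fused-lasso terms above reduce to simple masked reductions. A rough standalone sketch with made-up values, using a plain-torch analogue of util.masked_mean rather than the library call:

import torch

# Toy binary rationale samples and a padding mask (hypothetical values).
sample_z = torch.tensor([[1., 0., 1., 0.], [0., 1., 1., 0.]])
mask = torch.tensor([[1., 1., 1., 0.], [1., 1., 0., 0.]])

# Lasso term: fraction of non-padded tokens that were selected, i.e. a
# plain-torch analogue of util.masked_mean(sample_z, mask, dim=-1).
lasso_loss = (sample_z * mask).sum(-1) / mask.sum(-1).clamp(min=1e-5)

# Fused-lasso term: penalise transitions between adjacent tokens so the
# selected rationale tends to form contiguous spans.
diff = (sample_z[:, 1:] - sample_z[:, :-1]).abs()
masked_diff = (diff * mask[:, :-1]).sum(-1)
fused_lasso_loss = masked_diff / mask[:, :-1].sum(-1).clamp(min=1e-5)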
Example #30
    def forward(self,
                document,
                rationale=None,
                kept_tokens=None,
                query=None,
                label=None,
                metadata=None) -> Dict[str, Any]:
        generator_dict = self._generator(document, rationale)
        mask = util.get_text_field_mask(document)
        assert "probs" in generator_dict

        prob_z = generator_dict["probs"]
        assert len(prob_z.shape) == 2

        output_dict = {}

        sampler = D.bernoulli.Bernoulli(probs=prob_z)
        if self.prediction_mode or not self.training:
            if self._rationale_extractor is None:
                sample_z = generator_dict['predicted_rationale'].float()
            else:
                sample_z = self._rationale_extractor.extract_rationale(
                    prob_z, metadata, as_one_hot=True)
                output_dict[
                    "rationale"] = self._rationale_extractor.extract_rationale(
                        prob_z, metadata, as_one_hot=False)
                sample_z = torch.Tensor(sample_z).to(prob_z.device).float()
        else:
            sample_z = sampler.sample()

        sample_z = sample_z * mask
        reduced_document = self.regenerate_tokens(metadata, sample_z)
        encoder_dict = self._encoder(
            document=reduced_document,
            query=query,
            label=label,
            metadata=metadata,
        )

        loss = generator_dict['loss']

        if label is not None:
            assert "loss" in encoder_dict

            log_prob_z = sampler.log_prob(sample_z)  # (B, L)
            log_prob_z_sum = (mask * log_prob_z).sum(-1)  # (B,)
            loss_sample = F.cross_entropy(encoder_dict["logits"],
                                          label,
                                          reduction="none")  # (B,)

            sparsity = util.masked_mean(sample_z, mask, dim=-1)
            censored_lasso_loss = F.relu(sparsity - self._desired_length)

            diff = (sample_z[:, 1:] - sample_z[:, :-1]).abs()
            mask_last = mask[:, :-1]
            fused_lasso_loss = diff.sum(-1) / mask_last.sum(-1).clamp(min=1e-5)

            self._loss_tracks["_lasso_loss"](sparsity.mean().item())
            self._loss_tracks["_fused_lasso_loss"](
                fused_lasso_loss.mean().item())
            self._loss_tracks["_base_loss"](loss_sample.mean().item())

            base_loss = loss_sample
            generator_loss = (
                loss_sample.detach() + censored_lasso_loss *
                self._reg_loss_lambda + fused_lasso_loss *
                (self._reg_loss_mu * self._reg_loss_lambda)) * log_prob_z_sum

            loss += (base_loss + generator_loss).mean()

        output_dict["probs"] = encoder_dict["probs"]
        output_dict["predicted_labels"] = encoder_dict["predicted_labels"]

        output_dict["loss"] = loss
        output_dict["gold_labels"] = label
        output_dict["metadata"] = metadata

        output_dict["prob_z"] = generator_dict["prob_z"]
        output_dict["predicted_rationale"] = generator_dict[
            "predicted_rationale"]

        self._loss_tracks["_rat_length"](util.masked_mean(
            generator_dict["predicted_rationale"], mask, dim=-1).mean().item())

        self._call_metrics(output_dict)

        return output_dict
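The generator update in both rationale models is a score-function (REINFORCE) estimator. A compact sketch with hypothetical per-example numbers, showing why the encoder loss is detached inside the generator term:

import torch
import torch.distributions as D

# Hypothetical per-example quantities standing in for the model outputs.
prob_z = torch.tensor([[0.9, 0.2, 0.7], [0.4, 0.8, 0.1]])  # generator probabilities
mask = torch.tensor([[1., 1., 1.], [1., 1., 0.]])
loss_sample = torch.tensor([0.3, 1.2])       # per-example encoder loss
sparsity_penalty = torch.tensor([0.5, 0.2])  # e.g. a (censored) lasso term
reg_lambda = 0.1

sampler = D.Bernoulli(probs=prob_z)
sample_z = sampler.sample() * mask

# Score-function (REINFORCE) estimator: the encoder loss is detached so
# that gradients reach the generator only through log p(z).
log_prob_z_sum = (mask * sampler.log_prob(sample_z)).sum(-1)  # (B,)
generator_loss = (loss_sample.detach()
                  + reg_lambda * sparsity_penalty) * log_prob_z_sum
loss = (loss_sample + generator_loss).mean()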
    def forward(self,
                document,
                query=None,
                label=None,
                metadata=None,
                rationale=None) -> Dict[str, Any]:
        # pylint: disable=arguments-differ

        generator_dict = self._generator(document, query, label)
        mask = generator_dict["mask"]

        assert "a" in generator_dict
        assert "b" in generator_dict

        a, b = generator_dict["a"], generator_dict["b"]
        a = a.clamp(1e-6, 100.0)  # extreme values could result in NaNs
        b = b.clamp(1e-6, 100.0)  # extreme values could result in NaNs

        output_dict = {}

        sampler = HardKuma([a, b],
                           support=[
                               self.support[0].to(a.device),
                               self.support[1].to(b.device)
                           ])
        generator_dict["predicted_rationale"] = (sampler.mean() >
                                                 0.5).long() * mask

        if self.prediction_mode or not self.training:
            if self._rationale_extractor is None:
                # We constrain rationales to be strictly 0 or 1. See Pruthi et al.
                # for the pathologies that arise when this is not the case.
                sample_z = (sampler.mean() > 0.5).long() * mask
            else:
                prob_z = sampler.mean()
                sample_z = self._rationale_extractor.extract_rationale(
                    prob_z, document, as_one_hot=True)
                output_dict[
                    "rationale"] = self._rationale_extractor.extract_rationale(
                        prob_z, document, as_one_hot=False)
                sample_z = torch.Tensor(sample_z).to(prob_z.device).float()
        else:
            sample_z = sampler.sample()

        sample_z = sample_z * mask

        # BERT works on wordpieces, so map the token-level rationale back to
        # wordpiece positions before using it to scale the embeddings
        wordpiece_to_token = generator_dict["wordpiece-to-token"]
        wtt0 = torch.where(wordpiece_to_token == -1,
                           torch.tensor([0]).to(wordpiece_to_token.device),
                           wordpiece_to_token)
        wordpiece_sample = util.batched_index_select(sample_z.unsqueeze(-1),
                                                     wtt0)
        wordpiece_sample[wordpiece_to_token.unsqueeze(-1) == -1] = 1.0

        def scale_embeddings(module, input, output):
            output = output * wordpiece_sample
            return output

        hook = self._encoder.embedding_layers[0].register_forward_hook(
            scale_embeddings)

        encoder_dict = self._encoder(
            document=document,
            query=query,
            label=label,
            metadata=metadata,
        )

        hook.remove()

        loss = 0.0

        if label is not None:
            assert "loss" in encoder_dict

            base_loss = F.cross_entropy(encoder_dict["logits"], label)  # scalar (mean over batch)
            loss += base_loss

            pdf0 = sampler.pdf(0.0) * mask
            pdf_nonzero = (1 - pdf0) * mask
            lasso_loss = pdf_nonzero.sum(1)
            lengths = mask.sum(1)

            lasso_loss = lasso_loss / (lengths + 1e-9)
            lasso_loss = lasso_loss.mean()

            c0_hat = F.relu(lasso_loss - self._desired_length)
            if self.training:
                self.c0_ma = self.lagrange_alpha * self.c0_ma + (
                    1 - self.lagrange_alpha) * c0_hat.item()

            c0 = c0_hat + (self.c0_ma.detach() - c0_hat.detach())

            if self.training:
                self.lambda0 = self.lambda0 * torch.exp(
                    self.lagrange_lr * c0.detach())
                self.lambda0 = self.lambda0.clamp(self.lambda_min,
                                                  self.lambda_max)

            self._loss_tracks["_lasso_loss"](lasso_loss.item())
            self._loss_tracks["_base_loss"](base_loss.item())
            self._loss_tracks["_lambda0"](self.lambda0[0].item())
            self._loss_tracks["_c0_ma"](self.c0_ma[0].item())
            self._loss_tracks["_c0"](c0_hat.item())

            regulariser_loss = (self.lambda0.detach() * c0)[0]
            loss += regulariser_loss

        output_dict["probs"] = encoder_dict["probs"]
        output_dict["predicted_labels"] = encoder_dict["predicted_labels"]

        output_dict["loss"] = loss
        output_dict["gold_labels"] = label
        output_dict["metadata"] = metadata

        output_dict["predicted_rationale"] = generator_dict[
            "predicted_rationale"]

        self._loss_tracks["_rat_length"](
            util.masked_mean(generator_dict["predicted_rationale"],
                             mask == 1,
                             dim=-1).mean().item())

        self._call_metrics(output_dict)

        return output_dict
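The sparsity constraint here is handled with a Lagrangian relaxation: the multiplier lambda0 is updated multiplicatively from a moving average of the constraint violation. A simplified sketch with invented scalar values (not the module's actual buffers):

import torch
import torch.nn.functional as F

# Hypothetical running state for the Lagrangian sparsity constraint.
lambda0 = torch.tensor([1.0])     # Lagrange multiplier
c0_ma = torch.tensor([0.0])       # moving average of the constraint value
lagrange_alpha, lagrange_lr = 0.9, 0.01
lambda_min, lambda_max = 1e-6, 10.0
desired_length = 0.2              # target fraction of selected tokens

lasso_loss = torch.tensor(0.35)   # expected fraction actually selected

# Constraint violation: how far the selection exceeds the target length.
c0_hat = F.relu(lasso_loss - desired_length)

# Smooth the violation with a moving average, but keep the current
# batch's gradient (value from the average, gradient from c0_hat).
c0_ma = lagrange_alpha * c0_ma + (1 - lagrange_alpha) * c0_hat.item()
c0 = c0_hat + (c0_ma.detach() - c0_hat.detach())

# Multiplicative (exponentiated-gradient) update of the multiplier,
# clamped to a sane range.
lambda0 = (lambda0 * torch.exp(lagrange_lr * c0.detach())).clamp(lambda_min, lambda_max)

regulariser_loss = (lambda0.detach() * c0)[0]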
Example #32
    def forward(self,
                context_1: torch.Tensor,
                mask_1: torch.Tensor,
                context_2: torch.Tensor,
                mask_2: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # pylint: disable=arguments-differ
        """
        Given the forward (or backward) representations of sentence1 and sentence2, apply four bilateral
        matching functions between them in one direction.

        Parameters
        ----------
        context_1 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len1, hidden_dim) representing the encoding of the first sentence.
        mask_1 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len1), indicating which
            positions in the first sentence are padding (0) and which are not (1).
        context_2 : ``torch.Tensor``
            Tensor of shape (batch_size, seq_len2, hidden_dim) representing the encoding of the second sentence.
        mask_2 : ``torch.Tensor``
            Binary Tensor of shape (batch_size, seq_len2), indicating which
            positions in the second sentence are padding (0) and which are not (1).

        Returns
        -------
        A tuple of matching vectors for the two sentences, each of which is a list of
        matching vectors of shape (batch, seq_len, num_perspectives or 1).
        """
        assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
        assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim

        # (batch,)
        len_1 = get_lengths_from_binary_sequence_mask(mask_1)
        len_2 = get_lengths_from_binary_sequence_mask(mask_2)

        # (batch, seq_len*)
        mask_1, mask_2 = mask_1.float(), mask_2.float()

        # explicitly set masked weights to zero
        # (batch_size, seq_len*, hidden_dim)
        context_1 = context_1 * mask_1.unsqueeze(-1)
        context_2 = context_2 * mask_2.unsqueeze(-1)

        # array to keep the matching vectors for the two sentences
        matching_vector_1: List[torch.Tensor] = []
        matching_vector_2: List[torch.Tensor] = []

        # Step 0. unweighted cosine
        # First calculate the cosine similarities between each forward
        # (or backward) contextual embedding and every forward (or backward)
        # contextual embedding of the other sentence.

        # (batch, seq_len1, seq_len2)
        cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2), context_2.unsqueeze(-3), dim=3)

        # (batch, seq_len*, 1)
        cosine_max_1 = masked_max(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
        cosine_mean_1 = masked_mean(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
        cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)
        cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)

        matching_vector_1.extend([cosine_max_1, cosine_mean_1])
        matching_vector_2.extend([cosine_max_2, cosine_mean_2])

        # Step 1. Full-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with the last time step of the forward (or backward)
        # contextual embedding of the other sentence
        if self.with_full_match:

            if self.is_forward:
                # (batch, 1, hidden_dim)
                last_position_1 = (len_1 - 1).clamp(min=0)
                last_position_1 = last_position_1.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)
                last_position_2 = (len_2 - 1).clamp(min=0)
                last_position_2 = last_position_2.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)

                context_1_last = context_1.gather(1, last_position_1)
                context_2_last = context_2.gather(1, last_position_2)
            else:
                context_1_last = context_1[:, 0:1, :]
                context_2_last = context_2[:, 0:1, :]

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_full = multi_perspective_match(context_1,
                                                             context_2_last,
                                                             self.full_match_weights)
            matching_vector_2_full = multi_perspective_match(context_2,
                                                             context_1_last,
                                                             self.full_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_full)
            matching_vector_2.extend(matching_vector_2_full)

        # Step 2. Maxpooling-Matching
        # Each time step of forward (or backward) contextual embedding of one sentence
        # is compared with every time step of the forward (or backward)
        # contextual embedding of the other sentence, and only the max value of each
        # dimension is retained.
        if self.with_maxpool_match:
            # (batch, seq_len1, seq_len2, num_perspectives)
            matching_vector_max = multi_perspective_match_pairwise(context_1,
                                                                   context_2,
                                                                   self.maxpool_match_weights)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_max = masked_max(matching_vector_max,
                                               mask_2.unsqueeze(-2).unsqueeze(-1),
                                               dim=2)
            matching_vector_1_mean = masked_mean(matching_vector_max,
                                                 mask_2.unsqueeze(-2).unsqueeze(-1),
                                                 dim=2)
            matching_vector_2_max = masked_max(matching_vector_max.permute(0, 2, 1, 3),
                                               mask_1.unsqueeze(-2).unsqueeze(-1),
                                               dim=2)
            matching_vector_2_mean = masked_mean(matching_vector_max.permute(0, 2, 1, 3),
                                                 mask_1.unsqueeze(-2).unsqueeze(-1),
                                                 dim=2)

            matching_vector_1.extend([matching_vector_1_max, matching_vector_1_mean])
            matching_vector_2.extend([matching_vector_2_max, matching_vector_2_mean])


        # Step 3. Attentive-Matching
        # Each forward (or backward) similarity is taken as the weight
        # of the forward (or backward) contextual embedding, and calculate an
        # attentive vector for the sentence by weighted summing all its
        # contextual embeddings.
        # Finally match each forward (or backward) contextual embedding
        # with its corresponding attentive vector.

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_2 = context_2.unsqueeze(-3) * cosine_sim.unsqueeze(-1)

        # (batch, seq_len1, seq_len2, hidden_dim)
        att_1 = context_1.unsqueeze(-2) * cosine_sim.unsqueeze(-1)

        if self.with_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_mean_2 = masked_softmax(att_2.sum(dim=2), mask_1.unsqueeze(-1))
            att_mean_1 = masked_softmax(att_1.sum(dim=1), mask_2.unsqueeze(-1))

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_mean = multi_perspective_match(context_1,
                                                                 att_mean_2,
                                                                 self.attentive_match_weights)
            matching_vector_2_att_mean = multi_perspective_match(context_2,
                                                                 att_mean_1,
                                                                 self.attentive_match_weights_reversed)
            matching_vector_1.extend(matching_vector_1_att_mean)
            matching_vector_2.extend(matching_vector_2_att_mean)

        # Step 4. Max-Attentive-Matching
        # Pick the contextual embeddings with the highest cosine similarity as the attentive
        # vector, and match each forward (or backward) contextual embedding with its
        # corresponding attentive vector.
        if self.with_max_attentive_match:
            # (batch, seq_len*, hidden_dim)
            att_max_2 = masked_max(att_2, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
            att_max_1 = masked_max(att_1.permute(0, 2, 1, 3), mask_1.unsqueeze(-2).unsqueeze(-1), dim=2)

            # (batch, seq_len*, num_perspectives)
            matching_vector_1_att_max = multi_perspective_match(context_1,
                                                                att_max_2,
                                                                self.max_attentive_match_weights)
            matching_vector_2_att_max = multi_perspective_match(context_2,
                                                                att_max_1,
                                                                self.max_attentive_match_weights_reversed)

            matching_vector_1.extend(matching_vector_1_att_max)
            matching_vector_2.extend(matching_vector_2_att_max)

        return matching_vector_1, matching_vector_2
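Step 0 above is just masked max/mean pooling over a pairwise cosine-similarity matrix. A self-contained sketch with toy shapes and simplified masked_max/masked_mean helpers standing in for the library versions:

import torch
import torch.nn.functional as F

def masked_max(tensor, mask, dim, keepdim=False):
    # Fill masked positions with a very negative value before the max,
    # mirroring the behaviour of the masked_max used above.
    return tensor.masked_fill(~mask.bool(), -1e7).max(dim=dim, keepdim=keepdim)[0]

def masked_mean(tensor, mask, dim, keepdim=False):
    # Sum over unmasked positions and divide by their count.
    mask = mask.float()
    return (tensor * mask).sum(dim=dim, keepdim=keepdim) / \
        mask.sum(dim=dim, keepdim=keepdim).clamp(min=1e-13)

# Hypothetical encodings of two sentences.
batch, len1, len2, hidden = 2, 4, 5, 8
context_1 = torch.randn(batch, len1, hidden)
context_2 = torch.randn(batch, len2, hidden)
mask_1 = torch.ones(batch, len1)
mask_2 = torch.ones(batch, len2)

# (batch, len1, len2): cosine similarity between every pair of positions.
cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2),
                                 context_2.unsqueeze(-3), dim=3)

# (batch, len1, 1): pool over sentence 2, ignoring its padded positions.
cosine_max_1 = masked_max(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
cosine_mean_1 = masked_mean(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)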