Example #1
    def _compute_attention(
            self,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention over encoder outputs and decoder state.
        Parameters
        ----------
        decoder_hidden_state : ``torch.LongTensor``
            A tensor of shape ``(batch_size, decoder_output_dim)`` containing the decoder hidden state
            from the last time step, used as the 'query' in the attention computation.
        encoder_outputs : ``torch.LongTensor``
            A tensor of shape ``(batch_size, max_input_sequence_length, encoder_output_dim)``, which contains all the
            encoder hidden states of the source tokens, i.e., the 'keys' to the attention computation
        encoder_outputs_mask : ``torch.LongTensor``
            A tensor of shape (batch_size, max_input_sequence_length), which contains the mask of the encoded input.
            We want to avoid computing attention scores for padded source positions (remember that not all
            input sentences have the same length).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tensor of shape (batch_size, encoder_output_dim) containing the attended encoder outputs
            (a.k.a. the context vector), i.e., the attention scores ``applied`` to the encoder hidden states,
            together with the masked attention scores of shape (batch_size, max_input_sequence_length).

        Notes
        -----
            Don't forget to apply the final softmax over the **masked** encoder outputs!
        """

        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # Main body of attention weights computation here
        attention_scores = encoder_outputs.bmm(
            decoder_hidden_state.unsqueeze(-1)).squeeze(-1)
        masked_attention_scores = masked_softmax(attention_scores,
                                                 encoder_outputs_mask)
        attended_output = util.weighted_sum(encoder_outputs,
                                            masked_attention_scores)

        return attended_output, masked_attention_scores
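
The pattern above (dot-product scores, a softmax restricted to unpadded positions, then a weighted sum of the encoder states) can be reproduced with plain PyTorch. The sketch below is illustrative only, with made-up shapes; it stands in for the AllenNLP helpers `masked_softmax` and `util.weighted_sum` using `masked_fill` and `bmm`, and it assumes every row of the mask has at least one unpadded position (AllenNLP's `masked_softmax` also handles all-padded rows).

import torch

def masked_dot_attention(decoder_state, encoder_outputs, mask):
    # decoder_state: (batch, dim), encoder_outputs: (batch, seq, dim), mask: (batch, seq)
    scores = encoder_outputs.bmm(decoder_state.unsqueeze(-1)).squeeze(-1)   # (batch, seq)
    scores = scores.masked_fill(mask == 0, float("-inf"))                   # ignore padded positions
    weights = torch.softmax(scores, dim=-1)                                 # (batch, seq)
    context = weights.unsqueeze(1).bmm(encoder_outputs).squeeze(1)          # (batch, dim)
    return context, weights

batch, seq, dim = 2, 5, 8
context, weights = masked_dot_attention(torch.randn(batch, dim),
                                        torch.randn(batch, seq, dim),
                                        torch.ones(batch, seq))
assert context.shape == (batch, dim) and weights.shape == (batch, seq)
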
Example #2
    def _prepare_decode_step_input(
            self,
            input_indices: torch.LongTensor,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        """
        Given the input indices for the current timestep of the decoder, and all the encoder
        outputs, compute the input at the current timestep.  Note: This method is agnostic to
        whether the indices are gold indices or the predictions made by the decoder at the last
        timestep. So, this can be used even if we're doing some kind of scheduled sampling.

        If we're not using attention, the output of this method is just an embedding of the input
        indices.  If we are, the output will be a concatenation of the embedding and an attended
        average of the encoder inputs.

        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
            Output from the decoder at the last time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self._target_embedder(input_indices)
        if self._attention_function:
            # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
            # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
            # complain.
            encoder_outputs_mask = encoder_outputs_mask.float()
            # (batch_size, input_sequence_length)
            input_weights = self._decoder_attention(decoder_hidden_state,
                                                    encoder_outputs,
                                                    encoder_outputs_mask)
            # (batch_size, encoder_output_dim)
            attended_input = weighted_sum(encoder_outputs, input_weights)
            # (batch_size, encoder_output_dim + target_embedding_dim)
            return torch.cat((attended_input, embedded_input), -1)
        else:
            return embedded_input
Example #3
    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(
                decoder_hidden_state, encoder_outputs, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input
Example #4
    def _prepare_decode_step_input(self,
                                   input_indices: torch.LongTensor,
                                   decoder_hidden_state: torch.LongTensor = None,
                                   encoder_outputs: torch.LongTensor = None,
                                   encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        """
        Given the input indices for the current timestep of the decoder, and all the encoder
        outputs, compute the input at the current timestep.  Note: This method is agnostic to
        whether the indices are gold indices or the predictions made by the decoder at the last
        timestep. So, this can be used even if we're doing some kind of scheduled sampling.

        If we're not using attention, the output of this method is just an embedding of the input
        indices.  If we are, the output will be a concatenation of the embedding and an attended
        average of the encoder inputs.

        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
            Output from the decoder at the last time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self._target_embedder(input_indices)
        if self._attention_function:
            # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
            # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
            # complain.
            encoder_outputs_mask = encoder_outputs_mask.float()
            # (batch_size, input_sequence_length)
            input_weights = self._decoder_attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
            # (batch_size, encoder_output_dim)
            attended_input = weighted_sum(encoder_outputs, input_weights)
            # (batch_size, encoder_output_dim + target_embedding_dim)
            return torch.cat((attended_input, embedded_input), -1)
        else:
            return embedded_input
    def prepare_decode_step_input(
            self, input_indices: torch.LongTensor,
            decoder_hidden: torch.LongTensor,
            encoder_outputs: torch.LongTensor,
            encoder_outputs_mask: torch.LongTensor) -> torch.LongTensor:
        """
        Prepares the current timestep input for the decoder.

        By default, simply embeds and returns the input. If using attention, the default attention
        (BiLinearAttention) is applied over the encoded source sequence, using the previous hidden
        state as the query, and the attended context is concatenated to the embedded step input.

        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden : torch.LongTensor, optional (not needed if no attention)
            Output from the decoder at the last time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self.target_embedder(input_indices)
        if self.apply_attention:
            if isinstance(decoder_hidden, tuple):
                decoder_hidden = decoder_hidden[0]
            # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
            # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
            # complain.
            encoder_outputs_mask = encoder_outputs_mask.float()
            # (batch_size, input_sequence_length)
            input_weights = self.decoder_attention_function(
                decoder_hidden[-1], encoder_outputs, encoder_outputs_mask)
            # (batch_size, encoder_output_dim)
            attended_input = util.weighted_sum(encoder_outputs, input_weights)
            # (batch_size, encoder_output_dim + target_embedding_dim)
            return torch.cat((attended_input, embedded_input), -1)
        else:
            return embedded_input
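
For reference, here is a tiny, self-contained sketch of the "embed, attend, concatenate" step that both versions of this method perform. All sizes are made up, and the mean over the encoder outputs is only a stand-in for the attention step.

import torch
import torch.nn as nn

batch, seq, enc_dim, emb_dim, vocab = 2, 6, 16, 8, 50
embedder = nn.Embedding(vocab, emb_dim)

input_indices = torch.randint(0, vocab, (batch,))      # gold or predicted token ids
encoder_outputs = torch.randn(batch, seq, enc_dim)
attended_input = encoder_outputs.mean(dim=1)           # stand-in for the attention step
embedded_input = embedder(input_indices)               # (batch, emb_dim)
decoder_input = torch.cat((attended_input, embedded_input), dim=-1)
assert decoder_input.shape == (batch, enc_dim + emb_dim)
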
Example #6
    def _get_loss_context_words(
            self,
            embedded_context: torch.FloatTensor,
            embedded_pivot_phrase: torch.FloatTensor,
            context_words_mask: torch.LongTensor,
            batch_average: bool = True) -> torch.FloatTensor:
        """
        Returns
        -------
        A torch.FloatTensor representing the cross entropy loss.
        If ``batch_average == True``, the returned loss is a scalar.
        If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).

        """
        # (batch_size, num_context_words, emb_size) x (batch_size, emb_size, 1) -> (batch_size, num_context_words, 1)
        loss_context_words = torch.bmm(
            embedded_context,
            embedded_pivot_phrase.unsqueeze(1).transpose(1, 2))
        # (batch_size, num_context_words)
        loss_context_words = loss_context_words.squeeze(-1)
        # (batch_size, num_context_words)
        loss_context_words = loss_context_words.sigmoid().clamp(
            min=1e-20).log()
        # (batch_size, num_context_words)
        loss_context_words = loss_context_words * context_words_mask.float()
        # (batch_size,);
        # we add 1e-13 to avoid division by zero;
        # the numerator is zero anyway for fully masked rows because of the mask applied above
        per_batch_loss = loss_context_words.sum(1) / (
            context_words_mask.sum(1).float() + 1e-13)

        # make sure there are no infs, that rarely happens
        # per_batch_loss = per_batch_loss.clamp(min=1e-18, max=1e18)

        if batch_average:
            # (scalar)
            num_non_empty_sequences = (
                (context_words_mask.sum(1) > 0).float().sum() + 1e-13)
            # (scalar)
            return per_batch_loss.sum() / num_non_empty_sequences

        return per_batch_loss
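
The masked averaging at the end is easiest to see with concrete numbers. A small, made-up example of the same arithmetic:

import torch

per_position_loss = torch.tensor([[0.5, 0.3, 0.2],
                                  [0.4, 0.0, 0.0]])      # (batch_size=2, num_context_words=3)
mask = torch.tensor([[1., 1., 1.],
                     [1., 0., 0.]])                      # second row has only one real word

masked = per_position_loss * mask
per_batch_loss = masked.sum(1) / (mask.sum(1) + 1e-13)   # tensor([0.3333, 0.4000])
num_non_empty = (mask.sum(1) > 0).float().sum() + 1e-13
batch_average = per_batch_loss.sum() / num_non_empty     # tensor(0.3667)
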
Example #7
    def _decode_step_output(
            self,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
        # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
        # complain.
        encoder_outputs_mask = encoder_outputs_mask.float()
        # (batch_size, input_sequence_length)
        input_weights_e = self._decoder_attention(decoder_hidden_state,
                                                  encoder_outputs,
                                                  encoder_outputs_mask)
        input_weights_a = F.softmax(input_weights_e, dim=-1)
        # (batch_size, encoder_output_dim)
        attended_input = weighted_sum(encoder_outputs, input_weights_a)
        # H*_t = sum(h_i * at_i)
        # (batch_size, encoder_output_dim + decoder_hidden_dim)
        return input_weights_e, input_weights_a, torch.cat(
            (decoder_hidden_state, attended_input), -1)
Example #8
    def _prepare_attended_input(
        self,
        decoder_hidden_state: torch.LongTensor = None,
        encoder_outputs: torch.LongTensor = None,
        encoder_outputs_mask: torch.LongTensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs_mask = encoder_outputs_mask.float()
        # shape: (batch_size, max_input_sequence_length)
        input_logits = self._attention(decoder_hidden_state, encoder_outputs,
                                       encoder_outputs_mask)
        # the attention mechanism returns the logits that are needed for the attention supervision loss,
        # so we normalize them here
        input_weights = masked_softmax(input_logits, encoder_outputs_mask)
        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input, input_logits
Example #9
    def edge_existence(self, encoded_text: torch.Tensor,
                       mask: torch.LongTensor) -> torch.Tensor:
        """
        Computes edge existence scores for a batch of sentences.

        Parameters
        ----------
        encoded_text : torch.Tensor, required
            The encoded input sentence, with an artificial root node (head sentinel) prepended,
            of shape (batch_size, sequence_length, encoding_dim).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.

        Returns
        -------
        attended_arcs: torch.Tensor
            The edge existence scores in a tensor of shape (batch_size, sequence_length, sequence_length). The mask is taken into account.
        """
        float_mask = mask.float()

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        bs, sl, arc_dim = head_arc_representation.size()

        # Now repeat the token representations to form a matrix:
        # shape (batch_size, sequence_length, sequence_length, arc_representation_dim)
        heads = head_arc_representation.repeat(1, sl, 1).reshape(bs, sl, sl, arc_dim)  # heads in one direction
        deps = child_arc_representation.repeat(1, sl, 1).reshape(bs, sl, sl, arc_dim).transpose(1, 2)  # deps in the other direction

        # shape (batch_size, sequence_length, sequence_length, arc_representation_dim)
        # This completes the feedforward layer that takes every pair of token vectors;
        # `combined` holds the activations of the MLP's hidden layer.
        combined = self.activation(heads + deps)
        # The last dimension of the output layer is 1, so remove it.
        edge_scores = self.arc_out_layer(combined).squeeze(3)

        # Mask out scores for padded tokens:
        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        edge_scores = edge_scores + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
        return edge_scores
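
The repeat/reshape dance above builds a pairwise matrix in which entry (i, j) adds the head representation of token j to the dependent representation of token i. The small check below (made-up sizes, outside any class) shows that it is equivalent to plain broadcasting:

import torch

bs, sl, arc_dim = 2, 4, 3
heads = torch.randn(bs, sl, arc_dim)
deps = torch.randn(bs, sl, arc_dim)

via_repeat = (heads.repeat(1, sl, 1).reshape(bs, sl, sl, arc_dim)
              + deps.repeat(1, sl, 1).reshape(bs, sl, sl, arc_dim).transpose(1, 2))
via_broadcast = heads.unsqueeze(1) + deps.unsqueeze(2)   # [b, i, j] = heads[b, j] + deps[b, i]
assert torch.allclose(via_repeat, via_broadcast)
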
Example #10
    def forward(self, data: Dict, labels: torch.LongTensor):
        ge_list = []
        for i in range(self.num_graph):
            # get representation
            ge = self.emb(data['emb_ind'][i])
            if self.method == 'bow':
                # average using stat
                # SHAPE: (batch_size, emb_size)
                ge = (data['stat'][i].unsqueeze(-1) * ge).sum(1)  # TODO: add bias?
                #ge /= data['emb_ind'][i].ne(self.padding_ind).sum(-1).float().unsqueeze(-1)
                ge = F.relu(ge)
            else:
                if self.method != 'cosine':
                    ge = self.emb_proj(ge)
                if self.method == 'ggnn':
                    gnn = self.gnn.compute_node_representations(
                        initial_node_representation=ge, adjacency_lists=data['adj'][i])
                    # TODO: combine ggnn and emb
                    #ge = 0.5 * ge + 0.5 * gnn
                    ge = gnn
                # SHAPE: (batch_size, emb_size)
                ge = torch.index_select(ge, 0, data['prop_ind'][i])  # select property emb
            ge_list.append(ge)
        # match
        if self.match == 'concat':
            ge = torch.cat(ge_list, -1)
            ge = self.cla(ge)
        elif self.match == 'cosine':
            ge = self.cosine_cla(F.cosine_similarity(ge_list[0], ge_list[1]).unsqueeze(-1))

        if self.num_class == 1:
            # binary classification loss
            logits = torch.sigmoid(ge)
            labels = labels.float()
            loss = nn.BCELoss()(logits.squeeze(), labels)
        else:
            # cross-entropy loss
            logits = ge
            loss = nn.CrossEntropyLoss()(logits, labels)
        return logits, loss
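
An aside on the binary branch: pairing `torch.sigmoid` with `nn.BCELoss`, as done above, is numerically equivalent to `nn.BCEWithLogitsLoss` on the raw scores, which is usually preferred for stability. This is only an illustration of the equivalence, not a change to the snippet:

import torch
import torch.nn as nn

scores = torch.randn(4)                      # raw scores for four hypothetical examples
labels = torch.tensor([1., 0., 1., 0.])

loss_a = nn.BCELoss()(torch.sigmoid(scores), labels)
loss_b = nn.BCEWithLogitsLoss()(scores, labels)
assert torch.allclose(loss_a, loss_b, atol=1e-6)
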
Example #11
    def _decode_step_output(
            self,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        encoder_outputs_mask = encoder_outputs_mask.float()

        #decribe("encoder_outputs_mask", encoder_outputs_mask)

        input_weights_e = self._decoder_attention(decoder_hidden_state,
                                                  encoder_outputs,
                                                  encoder_outputs_mask)

        #decribe("input_weights_e", input_weights_e)

        input_weights_a = masked_softmax(input_weights_e, encoder_outputs_mask)

        #decribe("input_weights_a", input_weights_a)
        attended_input = weighted_sum(encoder_outputs, input_weights_a)

        #decribe("attended_input", attended_input)
        return input_weights_a, torch.cat(
            (decoder_hidden_state, attended_input), -1)
Example #12
    def forward(self,
                text: Dict[str, torch.Tensor],
                labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:

        embedded_text = self.text_field_embedder(text)
        mask = util.get_text_field_mask(text)
        encoded_text = self.encoder(embedded_text, mask)

        logits = self.classifier_feedforward(encoded_text)
        probabilities = torch.sigmoid(logits)

        output_dict = {'logits': logits, 'probabilities': probabilities}

        if labels is not None:
            loss = self.loss(logits + eps, labels.float())
            #loss = self.loss(logits, labels.squeeze(-1).long())
            output_dict['loss'] = loss

            predictions = (logits.data > 0.0).long()
            label_data = labels.squeeze(-1).data.long()
            self.f1(predictions, label_data)

        return output_dict
Example #13
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:

        mask_p = util.get_text_field_mask(premise)
        mask_h = util.get_text_field_mask(hypothesis)

        embedded_p = self.text_field_embedder(premise)
        encoded_p = self.encoder(embedded_p, mask_p)

        embedded_h = self.text_field_embedder(hypothesis)
        encoded_h = self.encoder(embedded_h, mask_h)

        fc_p, fc_h = self.feedforward(encoded_p, encoded_h)

        distance = F.pairwise_distance(fc_p, fc_h)
        prediction = distance < (self.margin / 2.0)
        output_dict = {'distance': distance, "prediction": prediction}

        if label is not None:
            """
            Contrastive loss function.
            Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
            """
            y = label.float()
            l1 = y * torch.pow(distance, 2) / 2.0
            l2 = (1 - y) * torch.pow(
                torch.clamp(self.margin - distance, min=0.0), 2) / 2.0
            loss = torch.mean(l1 + l2)

            self.accuracy(prediction, label.byte())

            output_dict["loss"] = loss

        return output_dict
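
The contrastive loss in the `label is not None` branch follows Hadsell et al. (2006): similar pairs (y = 1) are penalized by their squared distance, dissimilar pairs (y = 0) only when they fall inside the margin. A standalone version with made-up distances and a hypothetical margin of 1.0:

import torch

margin = 1.0
distance = torch.tensor([0.2, 1.5, 0.9])
y = torch.tensor([1., 0., 0.])

l_similar = y * distance.pow(2) / 2.0
l_dissimilar = (1 - y) * torch.clamp(margin - distance, min=0.0).pow(2) / 2.0
loss = torch.mean(l_similar + l_dissimilar)   # tensor(0.0083)
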
Example #14
    def edge_existence(self, encoded_text: torch.Tensor,
                       mask: torch.LongTensor) -> torch.Tensor:
        """
        Computes edge existence scores for a batch of sentences.

        Parameters
        ----------
        encoded_text : torch.Tensor, required
            The encoded input sentence, with an artificial root node (head sentinel) prepended,
            of shape (batch_size, sequence_length, encoding_dim).

        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.

        Returns
        -------
        attended_arcs: torch.Tensor
            The edge existence scores in a tensor of shape (batch_size, sequence_length, sequence_length). The mask is taken into account.
        """
        float_mask = mask.float()

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)
        return attended_arcs
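
The additive `-1e8` mask used here (and in the other parser snippets) pushes arcs that touch padded tokens far below every real score, so a later softmax effectively ignores them. A tiny demonstration with one padded token:

import torch

mask = torch.tensor([[1., 1., 0.]])                       # third token is padding
arc_scores = torch.zeros(1, 3, 3)
minus_mask = (1 - mask) * -1e8
masked_scores = arc_scores + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
probs = torch.softmax(masked_scores, dim=-1)
# In every real row the padded head now receives ~0 probability:
# probs[0, 0] == probs[0, 1] == tensor([0.5, 0.5, 0.])
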
Example #15
    def forward(
        self,
        word_ids: TextFieldTensors,
        entity_span: torch.LongTensor,
        labels: torch.LongTensor = None,
        entity_ids: torch.LongTensor = None,
        input_sentence: List[str] = None,
        **kwargs,
    ):
        feature_vector = self.feature_extractor(word_ids[self.text_field_key],
                                                entity_span, entity_ids)
        feature_vector = self.dropout(feature_vector)
        logits = self.classifier(feature_vector)

        output_dict = {
            "input": input_sentence,
            "prediction": torch.softmax(logits, dim=-1)
        }

        if labels is not None:
            output_dict["loss"] = self.criterion(logits, labels.float())
            output_dict["gold_label"] = labels
            self.f1_score(logits, labels)
        return output_dict
    def _decode_step_output(
            self,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        """
        Given all the encoder outputs, compute the input at the current timestep.  Note: This method is agnostic to
        whether the indices are gold indices or the predictions made by the decoder at the last
        timestep. So, this can be used even if we're doing some kind of scheduled sampling.

        Parameters
        ----------
        decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
            Output from the decoder at the last time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
        # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
        # complain.
        #import pdb
        #pdb.set_trace()
        encoder_outputs_mask = encoder_outputs_mask.float()
        # (batch_size, input_sequence_length)
        input_weights_e = self._decoder_attention(decoder_hidden_state,
                                                  encoder_outputs,
                                                  encoder_outputs_mask)
        input_weights_a = F.softmax(input_weights_e, dim=-1)
        # (batch_size, encoder_output_dim)
        attended_input = weighted_sum(encoder_outputs, input_weights_a)
        #H*_t = sum(h_i*at_i)
        # (batch_size, encoder_output_dim + decoder_hidden_dim)
        return input_weights_e, input_weights_a, torch.cat(
            (decoder_hidden_state, attended_input), -1)
Example #17
    def forward(self,  # type: ignore
                # words: Dict[str, torch.LongTensor],
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                pos_logits: torch.LongTensor = None,  # predicted
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, _, _ = encoded_text.size()

        pos_tags = None
        if pos_logits is not None and self.pos_tag_embedding is not None:
            # Embed the predicted POS tags and concatenate the embeddings to the input
            num_pos_classes = pos_logits.size(-1)
            pos_logits = pos_logits.view(-1, num_pos_classes)
            _, pos_tags = pos_logits.max(-1)

            pos_embed_size = self.pos_tag_embedding.get_output_dim()
            embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags))
            embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1)

        encoded_text = self.encoder(encoded_text, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            # "pos": [meta["pos"] for meta in metadata]
        }

        return output_dict
    def _parse(
        self,
        encoded_text: torch.Tensor,
        mask: torch.LongTensor,
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask,
            )
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask,
            )

        return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll
    def forward(
            self,
            sentences: Dict[str, torch.LongTensor],
            categories: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """

        :param sentences: Tensor of word indexes [batch_size * sample_size * seq_length]
        :param categories:
        :return:
        """

        # exclude tensors that are larger than the real number of tokens,
        # such as ngram tensors
        maskable_sentences = dict((key, val) for key, val in sentences.items()
                                  if '-ngram' not in key)

        # shape: (batch_size * sample_size * seq_length)
        mask = get_text_field_mask(maskable_sentences, num_wrapping_dims=1)

        batch_size, sample_size, seq_length = mask.shape

        flat_mask = mask.view(batch_size * sample_size, seq_length)

        # lengths = get_lengths_from_binary_sequence_mask(flat_mask)
        # sorted_mask, sorted_lengths, restoration_indices, permutation_index = sort_batch_by_length(flat_mask, lengths)

        # shape: ((batch_size * sample_size) * seq_length * embedding)
        embedded = self.text_embedder(sentences).view(batch_size * sample_size,
                                                      seq_length, -1)

        # shape: ((batch_size * sample_size) * seq_length * encoder_dim)
        encoder_outputs = self.encoder(embedded, flat_mask)

        # shape: ((batch_size * sample_size), encoder_output_dim)
        final_encoder_output = get_final_encoder_states(
            encoder_outputs, flat_mask, self.encoder.is_bidirectional())

        # shape: (batch_size * sample_size * encoder_output_dim)
        sentences_embedding = final_encoder_output.view(
            batch_size, sample_size, -1)

        # shape: ((batch_size * sample_size) * seq_length * embedding + encoder_dim)
        encoder_outputs = torch.cat([embedded, encoder_outputs], dim=-1)

        # shape: (batch_size, sample_size, seq_length, encoder_dim + embedding)
        encoder_outputs = encoder_outputs.view(batch_size, sample_size,
                                               seq_length, -1)

        mentions_embeddings, attention_weights = self.seq_combiner(
            encoder_outputs, mask, sentences_embedding)

        outputs = self._output_projection_layer(mentions_embeddings)

        if len(attention_weights) > 0:
            attention_weights = np.moveaxis(np.stack(attention_weights), 0, 1)
            attention_weights = np.split(attention_weights,
                                         len(attention_weights) // sample_size)

        result = {
            'predictions': torch.sigmoid(outputs),
            'attention_weights': attention_weights
        }

        if categories is not None:
            result['loss'] = self.loss(outputs, categories.float())
            # self.metrics['auc'](outputs.view(-1), categories.view(-1))
            # self.metrics['m-auc'](outputs, categories)
            self.metrics['f1']((outputs > 0.5).long(), categories)

        return result
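
The reshaping in this `forward` follows a flatten-encode-unflatten pattern: the (batch_size, sample_size, seq_length) inputs are viewed as one big batch for the encoder, and the final states are viewed back per sample afterwards. A shape-only sketch with assumed sizes:

import torch

batch_size, sample_size, seq_length, encoder_dim = 2, 3, 5, 4
mask = torch.ones(batch_size, sample_size, seq_length)

flat_mask = mask.view(batch_size * sample_size, seq_length)
final_encoder_output = torch.randn(batch_size * sample_size, encoder_dim)  # stand-in for the encoder
sentences_embedding = final_encoder_output.view(batch_size, sample_size, -1)
assert sentences_embedding.shape == (batch_size, sample_size, encoder_dim)
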
    def _joint_likelihood(self, logits: torch.Tensor, tags: torch.Tensor,
                          mask: torch.LongTensor) -> torch.Tensor:
        """
        Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
        """
        batch_size, sequence_length, num_tags = logits.data.shape

        # Transpose batch size and sequence dimensions:
        logits = logits.transpose(0, 1).contiguous()
        mask = mask.float().transpose(0, 1).contiguous()
        tags = tags.transpose(0, 1).contiguous()

        # Start with the transition scores from start_tag to the first tag in each input
        score = self.start_transitions.index_select(0, tags[0])

        # Broadcast the transition scores to one per batch element
        broadcast_transitions = self.transitions.view(
            1, num_tags, num_tags).expand(batch_size, num_tags, num_tags)

        # Add up the scores for the observed transitions and all the inputs but the last
        for i in range(sequence_length - 1):
            # Each is shape (batch_size,)
            current_tag, next_tag = tags[i], tags[i + 1]

            # The scores for transitioning from current_tag to next_tag
            transition_score = (
                    broadcast_transitions
                    # Choose the current_tag-th row for each input
                    .gather(1, current_tag.view(batch_size, 1, 1).expand(batch_size, 1, num_tags))
                    # Squeeze down to (batch_size, num_tags)
                    .squeeze(1)
                    # Then choose the next_tag-th column for each of those
                    .gather(1, next_tag.view(batch_size, 1))
                    # And squeeze down to (batch_size,)
                    .squeeze(1)
            )

            # The score for using current_tag
            emit_score = logits[i].gather(1, current_tag.view(batch_size,
                                                              1)).squeeze(1)

            # Include transition score if next element is unmasked,
            # input_score if this element is unmasked.
            score = score + transition_score * mask[i +
                                                    1] + emit_score * mask[i]

        # Transition from last state to "stop" state. To start with, we need to find the last tag
        # for each instance.
        last_tag_index = mask.sum(0).long() - 1
        last_tags = tags.gather(
            0,
            last_tag_index.view(1, batch_size).expand(sequence_length,
                                                      batch_size))

        # Is (sequence_length, batch_size), but all the columns are the same, so take the first.
        last_tags = last_tags[0]

        # Compute score of transitioning to `stop_tag` from each "last tag".
        last_transition_score = self.end_transitions.index_select(0, last_tags)

        # Add the last input if it's not masked.
        last_inputs = logits[-1]  # (batch_size, num_tags)
        last_input_score = last_inputs.gather(1, last_tags.view(
            -1, 1))  # (batch_size, 1)
        # last_input_score = last_input_score.squeeze()                    # (batch_size,)
        last_input_score = last_input_score.squeeze(1)

        score = score + last_transition_score + last_input_score * mask[-1]

        return score
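
The `gather` calls inside the loop are the trickiest part of this CRF numerator. The miniature example below (made-up logits, one tag picked per batch element) shows how the emission score of each element's current tag is selected from a (batch_size, num_tags) slice:

import torch

logits_i = torch.tensor([[0.1, 0.2, 0.3],
                         [0.4, 0.5, 0.6]])        # (batch_size=2, num_tags=3)
current_tag = torch.tensor([2, 0])                # tag chosen for each batch element
emit_score = logits_i.gather(1, current_tag.view(-1, 1)).squeeze(1)
# emit_score == tensor([0.3000, 0.4000])
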
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            verb_indicator: torch.LongTensor,
            target_index: torch.LongTensor,
            span_starts: torch.LongTensor,
            span_ends: torch.LongTensor,
            span_mask: torch.LongTensor,
            constituents: torch.LongTensor = None,
            tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer ``SequenceFeatureField`` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        bio : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``
        tags: shape ``(batch_size, num_spans)``
        span_starts: shape ``(batch_size, num_spans)``
        span_ends: shape ``(batch_size, num_spans)``

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        self.batch += 1
        embedded_text_input = self.embedding_dropout(
            self.text_field_embedder(tokens))
        batch_size = embedded_text_input.size(0)
        text_mask = util.get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(
            verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat(
            [embedded_text_input, embedded_verb_indicator], -1)
        embedding_dim_with_binary_feature = embedded_text_with_verb_indicator.size(2)

        if self.stacked_encoder.get_input_dim() != embedding_dim_with_binary_feature:
            raise ConfigurationError(
                "The SRL model uses an indicator feature, which makes "
                "the embedding dimension one larger than the value "
                "specified. Therefore, the 'input_dim' of the stacked_encoder "
                "must be equal to total_embedding_dim + 1.")
        encoded_text = self.stacked_encoder(embedded_text_with_verb_indicator,
                                            text_mask)

        span_starts = F.relu(span_starts.float()).long().view(batch_size, -1)
        span_ends = F.relu(span_ends.float()).long().view(batch_size, -1)
        target_index = F.relu(target_index.float()).long().view(batch_size)
        # shape (batch_size, sequence_length * max_span_width, embedding_dim)
        span_embeddings = span_srl_util.compute_span_representations(
            self.max_span_width, encoded_text, target_index, span_starts,
            span_ends, self.span_width_embedding,
            self.span_direction_embedding, self.span_distance_embedding,
            self.span_distance_bin, self.head_scorer)
        span_scores = self.span_feedforward(span_embeddings)

        srl_logits = self.srl_arg_projection_layer(span_scores)
        constit_logits = self.constit_arg_projection_layer(span_scores)
        output_dict = {
            "srl_logits": srl_logits,
            "constit_logits": constit_logits,
            "mask": text_mask
        }

        tags = tags.view(batch_size, -1, self.max_span_width)
        constituents = constituents.view(batch_size, -1, self.max_span_width)

        # Viterbi decoding
        if not self.training or (self.training and not self.fast_mode):
            srl_prediction, srl_probabilities = self.semi_crf.viterbi_tags(
                srl_logits, text_mask)
            output_dict["srl_tags"] = srl_prediction
            output_dict["srl_tag_probabilities"] = srl_probabilities
            self.metrics["srl"](predictions=srl_prediction.view(
                batch_size, -1, self.max_span_width),
                                gold_labels=tags,
                                mask=text_mask)

            reshaped_constit_logits = constit_logits.view(
                -1, self.num_constit_tags)
            constit_probabilities = F.softmax(reshaped_constit_logits, dim=-1)
            constit_predictions = constit_probabilities.max(-1)[1]
            output_dict["constit_tags"] = constit_predictions
            output_dict["constit_probabilities"] = constit_probabilities

            constit_predictions = constit_predictions.view(
                batch_size, -1, self.max_span_width)
            self.metrics["constituents"](predictions=constit_predictions,
                                         gold_labels=constituents,
                                         mask=text_mask)

        # Loss computation
        if self.training or (not self.training and not self.fast_mode):
            if tags is not None:
                srl_log_likelihood, _ = self.semi_crf(srl_logits,
                                                      tags,
                                                      mask=text_mask)
                output_dict["srl_loss"] = -srl_log_likelihood
            if constituents is not None:
                # Flattening it out.
                constituents = constituents.view(batch_size, -1)
                constit_loss = util.sequence_cross_entropy_with_logits(
                    constit_logits, constituents, span_mask)
                output_dict["constit_loss"] = constit_loss
            if tags is not None and constituents is not None:
                if self.batch > self.cutoff_batch:
                    output_dict["loss"] = - srl_log_likelihood + self.mixing_ratio * \
                        constit_loss
                else:
                    output_dict["loss"] = -srl_log_likelihood
        if self.fast_mode and not self.training:
            output_dict["loss"] = Variable(torch.FloatTensor([0.00]))

        return output_dict
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length size of the sequence_tensor.
            sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
            raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                             f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                             f"{sequence_lengths}.")

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            float_end_sentinel_mask = end_sentinel_mask.float()
            float_start_sentinel_mask = start_sentinel_mask.float()
            forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel
            backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                        + float_end_sentinel_mask * self._end_sentinel

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(self._forward_combination,
                                             [forward_start_embeddings, forward_end_embeddings])
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(self._backward_combination,
                                              [backward_start_embeddings, backward_end_embeddings])
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.float().unsqueeze(-1)
        return span_embeddings
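
The exclusive-start bookkeeping at the top of this extractor is worth isolating: AllenNLP `SpanField` indices are inclusive, so the forward direction reads the state just before the span start, and a sentinel mask flags spans that begin at position 0, where no such state exists. A shape-only sketch:

import torch

span_starts = torch.tensor([[0, 3]])                         # (batch, num_spans), inclusive
exclusive_span_starts = span_starts - 1                      # -1 means "no previous position"
start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))
# exclusive_span_starts == tensor([[0, 2]]); start_sentinel_mask.squeeze(-1) == tensor([[1, 0]])
# The flagged span is later replaced by the learned start sentinel vector.
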
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            verb_span: torch.LongTensor,
            entity_span: torch.LongTensor,
            state_change_type_labels: torch.LongTensor = None,
            state_change_tags: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_span: torch.LongTensor, required.
            An integer ``SequenceLabelField`` representation of the position of the focus verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that pre-processing stage could not extract a verbal predicate.
        entity_span: torch.LongTensor, required.
            An integer ``SequenceLabelField`` representation of the position of the focus entity
            in the sentence. This should have shape (batch_size, num_tokens).
        state_change_type_labels: torch.LongTensor, optional (default = None)
            A torch tensor representing the state change type class labels of shape ``(batch_size, 1)``.
        state_change_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``
            In the first implementation we focus only on state_change_types.

        Returns
        -------
        An output dictionary consisting of:
        type_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_state_change_types)`` representing
            a distribution of state change types per datapoint.
        tags_class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_state_change_types, num_tokens)`` representing
            a distribution of location tags per token in a sentence.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        # Layer 1 = Word + Character embedding layer
        embedded_sentence = self.text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        # Layer 2 = Add positional bit to encode position of focus verb and entity
        embedded_sentence_verb_entity = \
            torch.cat([embedded_sentence, verb_span.float().unsqueeze(-1), entity_span.float().unsqueeze(-1)], dim=-1)

        # Layer 3 = Contextual embedding layer using Bi-LSTM over the sentence
        contextual_embedding = self.seq2seq_encoder(
            embedded_sentence_verb_entity, mask)

        # Layer 4: Attention (Contextual embedding, BOW(verb span))
        verb_weight_matrix = verb_span.float() / (
            verb_span.float().sum(-1).unsqueeze(-1) + 1e-13)
        verb_vector = weighted_sum(
            contextual_embedding * verb_span.float().unsqueeze(-1),
            verb_weight_matrix)
        entity_weight_matrix = entity_span.float() / (
            entity_span.float().sum(-1).unsqueeze(-1) + 1e-13)
        entity_vector = weighted_sum(
            contextual_embedding * entity_span.float().unsqueeze(-1),
            entity_weight_matrix)
        verb_entity_vector = torch.cat([verb_vector, entity_vector], 1)
        batch_size, sequence_length, binary_feature_dim = \
            verb_span.float().unsqueeze(-1).size()

        # attention weights for type prediction
        attention_weights_types = self.attention_layer(verb_entity_vector,
                                                       contextual_embedding)
        attention_output_vector = weighted_sum(contextual_embedding,
                                               attention_weights_types)

        # contextual embedding + positional vectors for tag prediction
        context_positional_tags = torch.cat([
            contextual_embedding,
            verb_span.float().unsqueeze(-1),
            entity_span.float().unsqueeze(-1)
        ], dim=-1)

        # Layer 5 = Dense softmax layer to pick one state change type per datapoint,
        # and one tag per word in the sentence
        type_logits = self.aggregate_feedforward(attention_output_vector)
        type_probs = torch.nn.functional.softmax(type_logits, dim=-1)

        tags_logits = self.tag_projection_layer(context_positional_tags)
        reshaped_log_probs = tags_logits.view(-1, self.num_tags)
        tags_class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_tags])

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {'type_probs': type_probs}
        if state_change_type_labels is not None:
            state_change_type_labels_loss = self._loss(
                type_logits,
                state_change_type_labels.long().view(-1))
            for type_label in self.type_labels_vocab.values():
                metric = self.type_f1_metrics["type_" + type_label]
                metric(type_probs, state_change_type_labels.squeeze(-1))

            self._type_accuracy(type_probs,
                                state_change_type_labels.squeeze(-1))

        if state_change_tags is not None:
            state_change_tags_loss = sequence_cross_entropy_with_logits(
                tags_logits, state_change_tags, mask)
            self.span_metric(tags_class_probabilities, state_change_tags, mask)
            output_dict["tags_class_probabilities"] = tags_class_probabilities

        if state_change_type_labels is not None and state_change_tags is not None:
            output_dict['loss'] = (state_change_type_labels_loss +
                                   state_change_tags_loss)

        return output_dict
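# --- Illustrative sketch (not part of the original model code) ----------------
# Layer 4 above averages the contextual embeddings over the verb / entity span by
# normalising a 0/1 indicator into weights. A minimal stand-alone version of that
# indicator-normalised average, with hypothetical names and no AllenNLP helpers:
import torch


def span_average(contextual: torch.Tensor, indicator: torch.Tensor) -> torch.Tensor:
    """contextual: (batch, num_tokens, dim); indicator: (batch, num_tokens) of 0/1.

    Returns (batch, dim): the mean of the embeddings at indicator==1 positions
    (all zeros when the indicator is empty, thanks to the small epsilon).
    """
    weights = indicator.float() / (indicator.float().sum(-1, keepdim=True) + 1e-13)
    return torch.bmm(weights.unsqueeze(1), contextual).squeeze(1)


if __name__ == "__main__":
    contextual = torch.randn(2, 5, 8)
    indicator = torch.tensor([[0, 1, 1, 0, 0], [0, 0, 0, 0, 0]])
    vec = span_average(contextual, indicator)
    assert vec.shape == (2, 8)
    assert torch.allclose(vec[0], contextual[0, 1:3].mean(0), atol=1e-5)
    assert torch.allclose(vec[1], torch.zeros(8))
# ------------------------------------------------------------------------------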
Example #24
    def forward(self,
                images: torch.Tensor = None,
                objects: torch.LongTensor = None,
                segms: torch.Tensor = None,
                boxes: torch.Tensor = None,
                box_mask: torch.LongTensor = None,
                question: Dict[str, torch.Tensor] = None,
                question_tags: torch.LongTensor = None,
                question_mask: torch.LongTensor = None,
                answers: Dict[str, torch.Tensor] = None,
                answer_tags: torch.LongTensor = None,
                answer_mask: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                label: torch.LongTensor = None,
                bert_input_ids: torch.LongTensor = None,
                bert_input_mask: torch.LongTensor = None,
                bert_input_type_ids: torch.LongTensor = None,
                masked_lm_labels: torch.LongTensor = None,
                is_random_next: torch.LongTensor= None,
                image_text_alignment: torch.LongTensor = None,
                output_all_encoded_layers = False) -> Dict[str, torch.Tensor]:

        # Trim off boxes that are too long. This is an issue because DataParallel
        # pads with more zeros than are needed.

        max_len = int(box_mask.sum(1).max().item())
        objects = objects[:, :max_len]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        segms = segms[:, :max_len]
        '''for tag_type, the_tags in (('question', question_tags), ('answer', answer_tags)):
            if int(the_tags.max()) > max_len:
                raise ValueError("Oh no! {}_tags has maximum of {} but objects is of dim {}. Values are\n{}".format(
                    tag_type, int(the_tags.max()), objects.shape, the_tags
                ))'''
        obj_reps = self.detector(images=images, boxes=boxes, box_mask=box_mask, classes=objects, segms=segms)

        #print("obj_reps", obj_reps['obj_reps'].size())
        #print("bert_input_ids", bert_input_ids.size())
        #print("box_mask", box_mask.size())

        if len(bert_input_ids.size()) == 2: # Using complete shuffle mode
            obj_reps_expanded = obj_reps['obj_reps']
            box_mask_expanded = box_mask
        else:
            obj_reps_expanded = obj_reps['obj_reps'].unsqueeze(1).expand(box_mask.size(0), bert_input_mask.size(1), box_mask.size(-1), obj_reps['obj_reps'].size(-1))
            box_mask_expanded = box_mask.unsqueeze(1).expand(box_mask.size(0), bert_input_mask.size(1), box_mask.size(-1))

        #bert_input_mask = torch.cat((bert_input_mask, box_mask_expanded), dim = -1)

        output_dict = self.bert(
            input_ids = bert_input_ids,
            token_type_ids = bert_input_type_ids,
            input_mask = bert_input_mask,

            visual_embeddings = obj_reps_expanded,
            position_embeddings_visual = None,
            image_mask = box_mask_expanded,
            visual_embeddings_type = None,

            image_text_alignment = image_text_alignment,

            label = label,
            masked_lm_labels = masked_lm_labels,
            is_random_next = is_random_next,

            output_all_encoded_layers = output_all_encoded_layers)

        #class_probabilities = F.softmax(logits, dim=-1)
        cnn_loss = obj_reps['cnn_regularization_loss']
        if self.cnn_loss_ratio == 0.0:
            output_dict["cnn_regularization_loss"] = None
        else:
            output_dict["cnn_regularization_loss"] = cnn_loss * self.cnn_loss_ratio

        # Multi-process safe??
        if label is not None and self.training_head_type != "pretraining":
            logits = output_dict["logits"]
            logits = logits.detach().float()
            label = label.float()
            self._accuracy(logits, label)

        if self.training_head_type == "pretraining":
            output_dict["logits"] = None # Because every image may has different number of image features, the lengths of the logits on different GPUs will be different. This will cause DataParallel to throw errors.

        return output_dict
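# --- Illustrative sketch (not part of the original model code) ----------------
# When each question comes with several answer candidates, the per-image object
# features above are repeated along a new candidate axis with unsqueeze/expand
# (a view, no copy). A minimal stand-alone illustration with hypothetical names:
import torch


def expand_over_candidates(obj_reps: torch.Tensor, num_candidates: int) -> torch.Tensor:
    """obj_reps: (batch, num_objects, dim) -> (batch, num_candidates, num_objects, dim)."""
    batch, num_objects, dim = obj_reps.shape
    return obj_reps.unsqueeze(1).expand(batch, num_candidates, num_objects, dim)


if __name__ == "__main__":
    reps = torch.randn(2, 7, 16)
    tiled = expand_over_candidates(reps, num_candidates=4)
    assert tiled.shape == (2, 4, 7, 16)
    assert torch.equal(tiled[:, 0], tiled[:, 3])  # the same features, viewed 4 times
# ------------------------------------------------------------------------------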
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            targets: torch.LongTensor,
            target_index: torch.LongTensor,
            span_starts: torch.LongTensor,
            span_ends: torch.LongTensor,
            tags: torch.LongTensor = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        targets : torch.LongTensor, required
            An integer ``SequenceFeatureField`` representation of the position of the target verb
            in the sentence. This should have shape (batch_size, num_tokens) and, importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        target_index : torch.LongTensor, required
            The index of the target predicate for each sentence, of shape ``(batch_size,)``.
        span_starts : torch.LongTensor, required
            Start indices of the candidate spans, of shape ``(batch_size, num_spans)``.
        span_ends : torch.LongTensor, required
            End indices of the candidate spans, of shape ``(batch_size, num_spans)``.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold span labels
            of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.embedding_dropout(
            self.text_field_embedder(tokens))
        text_mask = util.get_text_field_mask(tokens)

        embedded_verb_indicator = self.binary_feature_embedding(targets.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat(
            [embedded_text_input, embedded_verb_indicator], -1)
        embedding_dim_with_binary_feature = embedded_text_with_verb_indicator.size(
        )[2]

        if self.stacked_encoder.get_input_dim(
        ) != embedding_dim_with_binary_feature:
            raise ConfigurationError(
                "The SRL model uses an indicator feature, which makes "
                "the embedding dimension one larger than the value "
                "specified. Therefore, the 'input_dim' of the stacked_encoder "
                "must be equal to total_embedding_dim + 1.")

        encoded_text = self.stacked_encoder(embedded_text_with_verb_indicator,
                                            text_mask)

        batch_size, num_spans = tags.size()
        assert num_spans % self.max_span_width == 0
        tags = tags.view(batch_size, -1, self.max_span_width)

        span_starts = F.relu(span_starts.float()).long().view(batch_size, -1)
        span_ends = F.relu(span_ends.float()).long().view(batch_size, -1)
        target_index = F.relu(target_index.float()).long().view(batch_size)

        # shape (batch_size, sequence_length * max_span_width, embedding_dim)
        span_embeddings = span_srl_util.compute_span_representations(
            self.max_span_width, encoded_text, target_index, span_starts,
            span_ends, self.span_width_embedding,
            self.span_direction_embedding, self.span_distance_embedding,
            self.span_distance_bin, self.head_scorer)
        span_scores = self.span_feedforward(span_embeddings)

        # FN-specific parameters.
        fn_args = []
        for extra_arg in ['frame', 'valid_frame_elements']:
            if extra_arg in kwargs and kwargs[extra_arg] is not None:
                fn_args.append(kwargs[extra_arg])

        if fn_args:  # FrameSRL batch.
            frame, valid_frame_elements = fn_args
            output_dict = self.compute_srl_graph(
                span_scores=span_scores,
                frame=frame,
                valid_frame_elements=valid_frame_elements,
                tags=tags,
                text_mask=text_mask,
                target_index=target_index)
        else:  # Scaffold batch.
            if "span_mask" in kwargs and kwargs["span_mask"] is not None:
                span_mask = kwargs["span_mask"]
            if "parent_tags" in kwargs and kwargs["parent_tags"] is not None:
                parent_tags = kwargs["parent_tags"]
            if self.unlabeled_constits:
                not_a_constit = self.vocab.get_token_index(
                    "*", self.constit_label_namespace)
                tags = (tags != not_a_constit).float().view(
                    batch_size, -1, self.max_span_width)
            elif self.constit_label_namespace == "parent_labels":
                tags = parent_tags.view(batch_size, -1, self.max_span_width)
            elif self.np_pp_constits:
                tags = self.get_new_tags_np_pp(tags, batch_size)
            output_dict = self.compute_constit_graph(span_mask=span_mask,
                                                     span_scores=span_scores,
                                                     constit_tags=tags,
                                                     text_mask=text_mask)

        if self.fast_mode and not self.training:
            output_dict["loss"] = Variable(torch.FloatTensor([0.00]))

        return output_dict
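# --- Illustrative sketch (not part of the original model code) ----------------
# The scaffold model above assumes spans are enumerated densely: every start
# position gets exactly `max_span_width` candidate end positions, so the flat
# (batch, num_spans) tensors can be viewed as (batch, seq_len, max_span_width).
# A minimal enumeration that produces that layout (hypothetical helper):
import torch


def enumerate_spans(seq_len: int, max_span_width: int):
    """Return (starts, ends), each of shape (seq_len * max_span_width,)."""
    starts = torch.arange(seq_len).repeat_interleave(max_span_width)
    widths = torch.arange(max_span_width).repeat(seq_len)
    ends = starts + widths  # may run past the sequence end; mask or clamp downstream
    return starts, ends


if __name__ == "__main__":
    starts, ends = enumerate_spans(seq_len=5, max_span_width=3)
    assert starts.numel() == 5 * 3
    # Viewing back recovers one row of width-candidates per start position.
    assert torch.equal(starts.view(5, 3)[:, 0], torch.arange(5))
    assert torch.equal(ends.view(5, 3)[2], torch.tensor([2, 3, 4]))
# ------------------------------------------------------------------------------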
Example #26
    def forward(
            self,
            source_sentences: Dict[str, torch.Tensor],
            sentence_tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        tokens = source_sentences['tokens']
        batch_size = tokens.size(0)
        sentences_count = tokens.size(1)
        max_sentence_length = tokens.size(2)
        tokens = tokens.reshape(batch_size * sentences_count,
                                max_sentence_length)

        sentences_embeddings = self._encode({'tokens': tokens})
        sentences_embeddings = sentences_embeddings.reshape(
            batch_size, sentences_count, -1)
        sentences_embeddings = self._dropout_layer(sentences_embeddings)

        h_sentences = self._sentence_accumulator(sentences_embeddings,
                                                 mask=None)
        h_sentences = self._dropout_layer(h_sentences)

        output_dict = dict()
        content = self._content_projection_layer(h_sentences).squeeze(2)
        output_dict['content'] = content
        predictions = content

        if self._use_salience:
            document_embedding = self._document_linear_layer(
                torch.mean(h_sentences, dim=1))
            document_embedding = torch.tanh(document_embedding)
            salience_intermediate = self._salience_linear_layer(
                document_embedding).unsqueeze(2)
            salience = torch.bmm(h_sentences, salience_intermediate).squeeze(2)
            predictions = predictions + salience

        if self._use_pos_embedding:
            assert sentences_count <= self._pos_embedding_num
            position_ids = util.get_range_vector(sentences_count,
                                                 tokens.device.index)
            position_ids = position_ids.unsqueeze(0).expand(
                (batch_size, sentences_count))
            positional_embeddings = self._pos_embeddingLayer(position_ids)
            positional_projection = self._pos_projection_layer(
                positional_embeddings).squeeze(2)
            predictions = predictions + positional_projection
            output_dict['pos'] = positional_projection

        if self._use_novelty:
            summary_representation = sentences_embeddings.new_zeros(
                (batch_size, self._h_sentence_dim))
            novelty = content.new_zeros(
                (batch_size, sentences_count))  # redundancy

            for sentence_num in range(sentences_count):
                novelty_intermediate = self._novelty_linear_layer(
                    torch.tanh(summary_representation)).unsqueeze(2)
                sentence_num_state = h_sentences[:, sentence_num, :]
                novelty[:, sentence_num] = -torch.bmm(
                    sentence_num_state.unsqueeze(1),
                    novelty_intermediate).squeeze(2).squeeze(1)
                predictions[:, sentence_num] += novelty[:, sentence_num]
                probabilities = torch.sigmoid(predictions[:, sentence_num])
                # Update each document's running summary with its sentence state,
                # weighted by the probability of selecting that sentence.
                summary_representation = summary_representation + \
                    probabilities.unsqueeze(1) * sentence_num_state

            output_dict['novelty'] = novelty

        if self._use_output_bias:
            predictions = predictions + self._output_bias

        output_dict['predicted_tags'] = predictions

        if sentence_tags is not None:
            loss = torch.nn.BCEWithLogitsLoss()(predictions,
                                                sentence_tags.float())
            output_dict['loss'] = loss

        return output_dict
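# --- Illustrative sketch (not part of the original model code) ----------------
# The salience term above scores every sentence against a document vector with a
# bilinear form h_i^T W d, implemented as a linear projection followed by a
# batched matrix-vector product. A minimal stand-alone version (hypothetical names):
import torch


def bilinear_salience(h_sentences: torch.Tensor, doc_vector: torch.Tensor,
                      weight: torch.nn.Linear) -> torch.Tensor:
    """h_sentences: (batch, num_sents, dim); doc_vector: (batch, dim) -> (batch, num_sents)."""
    projected = weight(doc_vector).unsqueeze(2)           # (batch, dim, 1)
    return torch.bmm(h_sentences, projected).squeeze(2)   # (batch, num_sents)


if __name__ == "__main__":
    batch, num_sents, dim = 2, 4, 8
    h = torch.randn(batch, num_sents, dim)
    d = torch.tanh(h.mean(dim=1))                         # document vector, as in the model above
    layer = torch.nn.Linear(dim, dim, bias=False)
    scores = bilinear_salience(h, d, layer)
    assert scores.shape == (batch, num_sents)
# ------------------------------------------------------------------------------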
Example #27
    def _joint_likelihood(self, logits: torch.Tensor, tags: torch.Tensor,
                          mask: torch.LongTensor) -> torch.Tensor:
        """
        Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
        """
        batch_size, sequence_length, _ = logits.data.shape

        # Transpose batch size and sequence dimensions:
        logits = logits.transpose(0, 1).contiguous()
        mask = mask.float().transpose(0, 1).contiguous()
        tags = tags.transpose(0, 1).contiguous()

        # Start with the transition scores from start_tag to the first tag in each input
        if self.include_start_end_transitions:
            score = self.start_transitions.index_select(0, tags[0])
        else:
            score = 0.0

        # Add up the scores for the observed transitions and all the inputs but the last
        for i in range(sequence_length - 1):
            # Each is shape (batch_size,)
            current_tag, next_tag = tags[i], tags[i + 1]

            # print("current_tag: ", current_tag)
            # print("next_tag: ", next_tag)
            # print("self.transitions: ", self.transitions)
            # print("self.transitions: ", self.transitions.size())

            # The scores for transitioning from current_tag to next_tag
            transition_score = self.transitions[current_tag.view(-1),
                                                next_tag.view(-1)]

            # The score for using current_tag
            emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)

            # Include transition score if next element is unmasked,
            # input_score if this element is unmasked.
            score = score + transition_score * mask[i + 1] + emit_score * mask[i]

        # Transition from last state to "stop" state. To start with, we need to find the last tag
        # for each instance.
        last_tag_index = mask.sum(0).long() - 1
        last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0)

        # Compute score of transitioning to `stop_tag` from each "last tag".
        if self.include_start_end_transitions:
            last_transition_score = self.end_transitions.index_select(
                0, last_tags)
        else:
            last_transition_score = 0.0

        # Add the last input if it's not masked.
        last_inputs = logits[-1]  # (batch_size, num_tags)
        last_input_score = last_inputs.gather(1, last_tags.view(-1, 1))  # (batch_size, 1)
        last_input_score = last_input_score.squeeze()  # (batch_size,)

        score = score + last_transition_score + last_input_score * mask[-1]

        return score
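# --- Illustrative sketch (not part of the original listing) -------------------
# Picking the tag at the last unmasked position of every sequence, as done above
# for the transition into the stop state. Tags here are laid out time-major,
# i.e. (sequence_length, batch_size); all names are hypothetical.
import torch


def last_unmasked_tags(tags: torch.LongTensor, mask: torch.Tensor) -> torch.LongTensor:
    """tags, mask: (seq_len, batch) -> (batch,) tag at each sequence's last real token."""
    last_index = mask.long().sum(0) - 1                    # (batch,)
    return tags.gather(0, last_index.view(1, -1)).squeeze(0)


if __name__ == "__main__":
    tags = torch.tensor([[3, 1], [4, 2], [5, 0]])          # seq_len=3, batch=2
    mask = torch.tensor([[1, 1], [1, 1], [1, 0]])          # second sequence has length 2
    assert torch.equal(last_unmasked_tags(tags, mask), torch.tensor([5, 2]))
# ------------------------------------------------------------------------------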
Example #28
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor, 
                gold_spans: torch.LongTensor, 
                tags: torch.LongTensor = None,
                span_labels: torch.LongTensor = None,
                gold_span_labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        '''
            tags : Shape (batch_size, seq_len)
                BILOU scheme tags for CRF modelling.
        '''
        
        batch_size = spans.size(0)
        # Adding mask
        mask = util.get_text_field_mask(tokens)

        token_mask = torch.cat([mask, 
                                mask.new_zeros(batch_size, 1)],
                                dim=1)

        embedded_text_input = self.text_field_embedder(tokens)

        embedded_text_input = torch.cat([embedded_text_input, 
                                         embedded_text_input.new_zeros(batch_size, 1, embedded_text_input.size(2))],
                                        dim=1)

        # span_mask Shape: (batch_size, num_spans), 1 or 0
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        gold_span_mask = (gold_spans[:,:,0] >=0).squeeze(-1).float()
        last_span_indices = gold_span_mask.sum(-1,keepdim=True).long()

        batch_indices = torch.arange(batch_size).unsqueeze(-1)
        batch_indices = util.move_to_device(batch_indices, 
                                            util.get_device_of(embedded_text_input))
        last_span_indices = torch.cat([batch_indices, last_span_indices],dim=-1)
        embedded_text_input[last_span_indices[:,0], last_span_indices[:,1]] += self.end_token_embedding.cuda(util.get_device_of(spans))

        token_mask[last_span_indices[:,0], last_span_indices[:,1]] += 1.
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.

        # spans Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()
        gold_spans = F.relu(gold_spans.float()).long()
        num_spans = spans.size(1)
        num_gold_spans = gold_spans.size(1)

        # Shape (batch_size, num_gold_spans, 4)
        hscrf_target = torch.cat([gold_spans, gold_spans.new_zeros(*gold_spans.size())],
                                 dim=-1)
        hscrf_target[:,:,2] = torch.cat([
            (gold_span_labels.new_zeros(batch_size, 1)+self.hscrf_layer.start_id).long(), # start tags in the front
            gold_span_labels.squeeze()[:,0:-1]],
            dim=-1)
        hscrf_target[:,:,3] = gold_span_labels.squeeze()
        # Shape (batch_size, num_gold_spans+1, 4)  including an <end> singular-span
        hscrf_target = torch.cat([hscrf_target, gold_spans.new_zeros(batch_size, 1, 4)],
                                 dim=1)

        hscrf_target[last_span_indices[:,0], last_span_indices[:,1],0:2] = \
                hscrf_target[last_span_indices[:,0], last_span_indices[:,1]-1][:,1:2] + 1

        hscrf_target[last_span_indices[:,0], last_span_indices[:,1],2] = \
                hscrf_target[last_span_indices[:,0], last_span_indices[:,1]-1][:,3]

        hscrf_target[last_span_indices[:,0], last_span_indices[:,1],3] = \
                self.hscrf_layer.stop_id

        gold_span_mask = torch.cat([gold_span_mask.float(),
                                    gold_span_mask.new_zeros(batch_size, 1).float()], dim=-1)
        gold_span_mask[last_span_indices[:,0], last_span_indices[:,1]] = 1.

        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, token_mask)

        if self.dropout:
            encoded_text = self.dropout(encoded_text)

        if self._feedforward is not None:
            encoded_text = self._feedforward(encoded_text)

        hscrf_neg_log_likelihood = self.hscrf_layer(
            encoded_text, 
            tokens,
            token_mask.sum(-1).squeeze(),
            hscrf_target,
            gold_span_mask
        )

        pred_results = self.hscrf_layer.get_scrf_decode(
            token_mask.sum(-1).squeeze()
        )
        self._span_f1_metric(
            pred_results, 
            [dic['gold_spans'] for dic in metadata],
            sentences=[x["words"] for x in metadata])
        output = {
            "mask": token_mask,
            "loss": hscrf_neg_log_likelihood,
            "results": pred_results
        }
        
        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]
        return output
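# --- Illustrative sketch (not part of the original model code) ----------------
# The model above appends a padding position to every sequence and then writes an
# extra end-token embedding at a different position for each batch element, using
# (batch_index, position) pairs and advanced indexing. A minimal stand-alone
# version of that indexing pattern (hypothetical names, fixed embedding):
import torch


def add_end_embedding(embedded: torch.Tensor, lengths: torch.LongTensor,
                      end_embedding: torch.Tensor) -> torch.Tensor:
    """embedded: (batch, seq_len, dim); lengths: (batch,); end_embedding: (dim,)."""
    out = embedded.clone()
    batch_indices = torch.arange(out.size(0))
    out[batch_indices, lengths] = out[batch_indices, lengths] + end_embedding
    return out


if __name__ == "__main__":
    embedded = torch.zeros(2, 5, 3)
    lengths = torch.tensor([2, 4])          # position right after each sequence's last token
    end_emb = torch.ones(3)
    out = add_end_embedding(embedded, lengths, end_emb)
    assert torch.equal(out[0, 2], torch.ones(3)) and torch.equal(out[1, 4], torch.ones(3))
    assert out.sum() == 6.0                 # only two positions were modified
# ------------------------------------------------------------------------------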
    def compute_constrained_marginal(
        self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor, annotation_mask: torch.LongTensor
    ) -> torch.Tensor:
        """
        Note: This function is not used unless you want to compute the marginal probability
        Forward-backward algorithm to compute the marginal probability (in log space)
        Basically, we follow the `backward` algorithm to obtain the backward scores.
        :param lstm_scores:   shape: (batch_size, sent_len, label_size) NOTE: the scores from the LSTMs, not `all_scores` (which add up the transition)
        :param word_seq_lens: shape: (batch_size,)
        :param annotation_mask: shape: (batch_size, sent_len, label_size)
        :return: Marginal score. If you want probability, you need to use `torch.exp` to convert it into probability
                shape: (batch_size, sent_len, label_size)
        """
        batch_size = lstm_scores.size(0)
        seq_len = lstm_scores.size(1)
        mask = annotation_mask.float().log()
        alpha = torch.zeros(batch_size, seq_len, self.label_size).to(self.device)
        beta = torch.zeros(batch_size, seq_len, self.label_size).to(self.device)

        scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(
            batch_size, seq_len, self.label_size, self.label_size
        ) + lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(
            batch_size, seq_len, self.label_size, self.label_size
        )
        ## Reverse the view used to compute the score: we look at the sequence from the back.
        rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size).expand(
            batch_size, seq_len, self.label_size, self.label_size
        ) + lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(
            batch_size, seq_len, self.label_size, self.label_size
        )

        perm_idx = torch.zeros(batch_size, seq_len).to(self.device)
        for batch_idx in range(batch_size):
            length = int(word_seq_lens[batch_idx])
            perm_idx[batch_idx][:length] = torch.arange(length - 1, -1, -1)
        perm_idx = perm_idx.long()
        for i, length in enumerate(word_seq_lens):
            rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]

        alpha[:, 0, :] = scores[
            :, 0, self.start_idx, :
        ]  ## the first position of all labels = (the transition from start - > all labels) + current emission.
        alpha[:, 0, :] += mask[:, 0, :]
        beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
        for word_idx in range(1, seq_len):
            before_log_sum_exp = (
                alpha[:, word_idx - 1, :]
                .view(batch_size, self.label_size, 1)
                .expand(batch_size, self.label_size, self.label_size)
                + scores[:, word_idx, :, :]
            )
            alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)
            alpha[:, word_idx, :] += mask[:, word_idx, :]
            before_log_sum_exp = (
                beta[:, word_idx - 1, :]
                .view(batch_size, self.label_size, 1)
                .expand(batch_size, self.label_size, self.label_size)
                + rev_score[:, word_idx, :, :]
            )
            beta[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)
            beta[:, word_idx, :] += mask[:, word_idx, :]

        ### batch_size x label_size
        last_alpha = torch.gather(
            alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size) - 1
        ).view(batch_size, self.label_size)
        last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
        last_alpha = (
            log_sum_exp_pytorch(last_alpha.view(batch_size, self.label_size, 1))
            .view(batch_size, 1, 1)
            .expand(batch_size, seq_len, self.label_size)
        )

        ## Because we need to use the beta variable later, we need to reverse back
        for i, length in enumerate(word_seq_lens):
            beta[i, :length] = beta[i, :length][perm_idx[i, :length]]

        # `alpha + beta - last_alpha` is the standard way to obtain the marginal.
        # However, the forward and backward passes each include the emission score at every
        # position, so we subtract one copy of the emission scores.
        return alpha + beta - last_alpha - lstm_scores
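# --- Illustrative sketch (not part of the original listing) -------------------
# The constrained forward-backward pass above turns a 0/1 annotation mask into an
# additive log-space mask: log(1) = 0 keeps a label, log(0) = -inf rules it out
# once added to a score. A tiny stand-alone illustration (hypothetical names):
import torch


def apply_label_constraints(scores: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
    """scores, allowed: (..., label_size); allowed holds 0/1 -> constrained log-scores."""
    return scores + allowed.float().log()


if __name__ == "__main__":
    scores = torch.tensor([[1.0, 2.0, 3.0]])
    allowed = torch.tensor([[1, 0, 1]])
    constrained = apply_label_constraints(scores, allowed)
    probs = torch.softmax(constrained, dim=-1)
    assert probs[0, 1] == 0.0               # the forbidden label gets zero probability
    assert torch.isclose(probs.sum(), torch.tensor(1.0))
# ------------------------------------------------------------------------------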
    def forward(
            self,
            passage_attention: torch.Tensor,
            passage_lengths: List[int],
            count_answer: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        device_id = allenutil.get_device_of(passage_attention)

        batch_size, max_passage_length = passage_attention.size()

        # Shape: (B, passage_length)
        passage_mask = (passage_attention >= 0).float()

        # List of (B, P) shaped tensors
        scaled_attentions = [
            passage_attention * sf for sf in self.scaling_vals
        ]
        # Shape: (B, passage_length, num_scaling_factors)
        scaled_passage_attentions = torch.stack(scaled_attentions, dim=2)

        # Shape (batch_size, 1)
        passage_len_bias = self.passagelength_to_bias(
            passage_mask.sum(1, keepdim=True))

        scaled_passage_attentions = scaled_passage_attentions * passage_mask.unsqueeze(
            2)

        # Shape: (B, passage_length, hidden_dim)
        count_hidden_repr = self.passage_attention_to_count(
            scaled_passage_attentions, passage_mask)

        # Shape: (B, passage_length, 1) -- score for each token
        passage_span_logits = self.passage_count_hidden2logits(
            count_hidden_repr)
        # Shape: (B, passage_length) -- sigmoid on token-score
        token_sigmoids = torch.sigmoid(passage_span_logits.squeeze(2))
        token_sigmoids = token_sigmoids * passage_mask

        # Shape: (B, 1) -- sum of sigmoids. This will act as the predicted mean
        # passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True) + passage_len_bias
        passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True)

        # Shape: (1, count_vals)
        self.countvals = allenutil.get_range_vector(
            10, device=device_id).unsqueeze(0).float()

        variance = 0.2

        # Shape: (batch_size, count_vals)
        l2_by_vsquared = torch.pow(self.countvals - passage_count_mean,
                                   2) / (2 * variance * variance)
        exp_val = torch.exp(-1 * l2_by_vsquared) + 1e-30
        # Shape: (batch_size, count_vals)
        count_distribution = exp_val / (torch.sum(exp_val, 1, keepdim=True))

        # Loss computation
        output_dict = {}
        loss = 0.0
        pred_count_idx = torch.argmax(count_distribution, 1)
        if count_answer is not None:
            # L2-loss
            passage_count_mean = passage_count_mean.squeeze(1)
            L2Loss = F.mse_loss(input=passage_count_mean,
                                target=count_answer.float())
            loss = L2Loss
            predictions = passage_count_mean.detach().cpu().numpy()
            predictions = np.round_(predictions)

            gold_count = count_answer.detach().cpu().numpy()
            correct_vec = (predictions == gold_count)
            correct_perc = sum(correct_vec) / batch_size
            # print(f"{correct_perc} {predictions} {gold_count}")
            self.count_acc(correct_perc)

            # loss = F.cross_entropy(input=count_distribution, target=count_answer)
            # List of predicted count idxs, Shape: (B,)
            # correct_vec = (pred_count_idx == count_answer).float()
            # correct_perc = torch.sum(correct_vec) / batch_size
            # self.count_acc(correct_perc.item())

        batch_loss = loss / batch_size
        output_dict["loss"] = batch_loss
        output_dict["passage_attention"] = passage_attention
        output_dict["passage_sigmoid"] = token_sigmoids
        output_dict["count_mean"] = passage_count_mean
        output_dict["count_distritbuion"] = count_distribution
        output_dict["count_answer"] = count_answer
        output_dict["pred_count"] = pred_count_idx

        return output_dict
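# --- Illustrative sketch (not part of the original model code) ----------------
# The count module above converts a predicted scalar mean into a soft distribution
# over the integers 0..9 with an (unnormalised) Gaussian-like kernel around the
# mean. A minimal stand-alone version of that discretisation (hypothetical names):
import torch


def mean_to_count_distribution(mean: torch.Tensor, num_values: int = 10,
                               variance: float = 0.2) -> torch.Tensor:
    """mean: (batch, 1) -> (batch, num_values), rows summing to 1."""
    values = torch.arange(num_values, dtype=torch.float).unsqueeze(0)  # (1, num_values)
    logits = -(values - mean) ** 2 / (2 * variance * variance)
    weights = torch.exp(logits) + 1e-30
    return weights / weights.sum(dim=1, keepdim=True)


if __name__ == "__main__":
    dist = mean_to_count_distribution(torch.tensor([[3.2], [0.0]]))
    assert dist.shape == (2, 10)
    assert dist[0].argmax().item() == 3      # mass concentrates on the nearest integer
    assert torch.allclose(dist.sum(dim=1), torch.ones(2))
# ------------------------------------------------------------------------------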
Example #31
    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        grammar_value_logits = self._gram_val_output(encoded_text)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

        # Feed the grammatical-label prediction output into the lemmatizer input -- EXPERIMENTAL
        #l_ext_input = encoded_text
        l_ext_input = torch.cat([encoded_text, grammar_value_logits], -1)
        lemma_logits = self._lemma_output(l_ext_input)
        predicted_lemmas = lemma_logits.argmax(-1)

        # Obtain the top-N most probable lemmatization variants and their probability estimates
        lemma_probs = torch.nn.functional.softmax(lemma_logits, -1)
        top_lemmas_indices = (-lemma_logits).argsort(-1)[:, :, :self.TopNCnt]
        #top_lemmas_indices = (-lemma_probs).argsort(-1)[:,:,:self.TopNCnt]
        top_lemmas_prob = torch.gather(lemma_probs, -1, top_lemmas_indices)
        #top_lemmas_prob = torch.gather(lemma_logits, -1, top_lemmas_indices)

        # The same for the grammemes
        gramm_probs = torch.nn.functional.softmax(grammar_value_logits, -1)
        top_gramms_indices = (
            -grammar_value_logits).argsort(-1)[:, :, :self.TopNCnt]
        top_gramms_prob = torch.gather(gramm_probs, -1, top_gramms_indices)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            synt_prediction, benrg = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
            predicted_heads, predicted_head_tags = synt_prediction

        # Obtain the top-N most probable LOCAL (not MST) parse variants and their probability estimates
        benrgf = torch.flatten(benrg, start_dim=1, end_dim=2).permute(
            0, 2, 1)  # fuses the syntactic relation type with the parent index
        top_deprels_indices = (-benrgf).argsort(
            -1)[:, :, :self.TopNCnt]  # pick the best combinations
        top_deprels_prob = torch.gather(benrgf, -1, top_deprels_indices)
        seqlen = benrg.shape[2]
        top_heads = torch.fmod(top_deprels_indices, seqlen)
        top_deprels_indices = torch.div(top_deprels_indices,
                                        seqlen)  # torch.floor does not work here

        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask)
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask)

        grammar_nll = torch.tensor(0.)
        if grammar_values is not None:
            grammar_nll = self._update_multiclass_prediction_metrics(
                logits=grammar_value_logits,
                targets=grammar_values,
                mask=token_mask,
                accuracy_metric=self._gram_val_prediction_accuracy)

        lemma_nll = torch.tensor(0.)
        if lemma_indices is not None:
            lemma_nll = self._update_multiclass_prediction_metrics(
                logits=lemma_logits,
                targets=lemma_indices,
                mask=token_mask,
                accuracy_metric=self._lemma_prediction_accuracy,
                masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
            "top_lemmas": top_lemmas_indices,
            "top_lemmas_prob": top_lemmas_prob,
            "top_gramms": top_gramms_indices,
            "top_gramms_prob": top_gramms_prob,
            "top_heads": top_heads,
            "top_deprels": top_deprels_indices,
            "top_deprels_prob": top_deprels_prob,
        }

        return output_dict
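# --- Illustrative sketch (not part of the original model code) ----------------
# The parser above extracts the top-N labels per token by sorting negated scores
# and gathering the matching probabilities. A minimal stand-alone version of that
# pattern, with the equivalent torch.topk call as a cross-check (hypothetical names):
import torch


def top_n_with_probs(logits: torch.Tensor, n: int):
    """logits: (batch, seq_len, num_labels) -> (indices, probs), each (batch, seq_len, n)."""
    probs = torch.softmax(logits, dim=-1)
    indices = (-logits).argsort(dim=-1)[:, :, :n]
    return indices, torch.gather(probs, -1, indices)


if __name__ == "__main__":
    logits = torch.randn(2, 3, 7)
    idx, prb = top_n_with_probs(logits, n=2)
    top_prb, top_idx = torch.softmax(logits, dim=-1).topk(2, dim=-1)
    assert torch.equal(idx, top_idx) and torch.allclose(prb, top_prb)
# ------------------------------------------------------------------------------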
Example #32
    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(
            int(self._input_dim / 2), dim=-1)
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [
            index.squeeze(-1) for index in span_indices.split(1, dim=-1)
        ]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (
            exclusive_span_starts == -1).long().unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(
                sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length of the sequence_tensor.
            sequence_lengths = (
                torch.ones_like(sequence_tensor[:, 0, 0], dtype=torch.long) *
                sequence_tensor.size(1))

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(
            -1)).long().unsqueeze(-1)

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * (
            1 - end_sentinel_mask.squeeze(-1))
        exclusive_span_starts = exclusive_span_starts * (
            1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (
                exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
            raise ValueError(
                f"Adjusted span indices must lie inside the length of the sequence tensor, "
                f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                f"{sequence_lengths}.")

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(
            forward_sequence, exclusive_span_starts)
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(
            forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(
            backward_sequence, exclusive_span_ends)
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(
            backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            float_end_sentinel_mask = end_sentinel_mask.float()
            float_start_sentinel_mask = start_sentinel_mask.float()
            forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel
            backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                        + float_end_sentinel_mask * self._end_sentinel

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(
            self._forward_combination,
            [forward_start_embeddings, forward_end_embeddings])
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(
            self._backward_combination,
            [backward_start_embeddings, backward_end_embeddings])
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_ends - span_starts,
                    num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.float().unsqueeze(-1)
        return span_embeddings
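# --- Illustrative sketch (not part of the original extractor code) ------------
# When a span starts at token 0, its exclusive start index becomes -1; the
# extractor above zeroes that index out and substitutes a learned start sentinel
# instead. A minimal stand-alone version of that masked substitution
# (hypothetical names, a fixed sentinel instead of a learned parameter):
import torch


def substitute_sentinel(embeddings: torch.Tensor, sentinel_mask: torch.Tensor,
                        sentinel: torch.Tensor) -> torch.Tensor:
    """embeddings: (batch, num_spans, dim); sentinel_mask: (batch, num_spans, 1) of 0/1."""
    mask = sentinel_mask.float()
    return embeddings * (1 - mask) + mask * sentinel


if __name__ == "__main__":
    span_starts = torch.tensor([[0, 2]])                 # first span starts at the sequence start
    exclusive_starts = span_starts - 1
    sentinel_mask = (exclusive_starts == -1).long().unsqueeze(-1)
    embeddings = torch.ones(1, 2, 4)
    sentinel = torch.full((4,), 9.0)
    out = substitute_sentinel(embeddings, sentinel_mask, sentinel)
    assert torch.equal(out[0, 0], sentinel) and torch.equal(out[0, 1], torch.ones(4))
# ------------------------------------------------------------------------------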
    def _joint_likelihood(self,
                          logits: torch.Tensor,
                          tags: torch.Tensor,
                          mask: torch.LongTensor) -> torch.Tensor:
        """
        Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
        """
        batch_size, sequence_length, num_tags = logits.data.shape

        # Transpose batch size and sequence dimensions:
        logits = logits.transpose(0, 1).contiguous()
        mask = mask.float().transpose(0, 1).contiguous()
        tags = tags.transpose(0, 1).contiguous()

        # Start with the transition scores from start_tag to the first tag in each input
        if self.include_start_end_transitions:
            score = self.start_transitions.index_select(0, tags[0])
        else:
            score = 0.0

        # Broadcast the transition scores to one per batch element
        broadcast_transitions = self.transitions.view(1, num_tags, num_tags).expand(batch_size, num_tags, num_tags)

        # Add up the scores for the observed transitions and all the inputs but the last
        for i in range(sequence_length - 1):
            # Each is shape (batch_size,)
            current_tag, next_tag = tags[i], tags[i+1]

            # The scores for transitioning from current_tag to next_tag
            transition_score = (
                    broadcast_transitions
                    # Choose the current_tag-th row for each input
                    .gather(1, current_tag.view(batch_size, 1, 1).expand(batch_size, 1, num_tags))
                    # Squeeze down to (batch_size, num_tags)
                    .squeeze(1)
                    # Then choose the next_tag-th column for each of those
                    .gather(1, next_tag.view(batch_size, 1))
                    # And squeeze down to (batch_size,)
                    .squeeze(1)
            )

            # The score for using current_tag
            emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)

            # Include transition score if next element is unmasked,
            # input_score if this element is unmasked.
            score = score + transition_score * mask[i + 1] + emit_score * mask[i]

        # Transition from last state to "stop" state. To start with, we need to find the last tag
        # for each instance.
        last_tag_index = mask.sum(0).long() - 1
        last_tags = tags.gather(0, last_tag_index.view(1, batch_size).expand(sequence_length, batch_size))

        # Is (sequence_length, batch_size), but all the columns are the same, so take the first.
        last_tags = last_tags[0]

        # Compute score of transitioning to `stop_tag` from each "last tag".
        if self.include_start_end_transitions:
            last_transition_score = self.end_transitions.index_select(0, last_tags)
        else:
            last_transition_score = 0.0

        # Add the last input if it's not masked.
        last_inputs = logits[-1]                                         # (batch_size, num_tags)
        last_input_score = last_inputs.gather(1, last_tags.view(-1, 1))  # (batch_size, 1)
        last_input_score = last_input_score.squeeze()                    # (batch_size,)

        score = score + last_transition_score + last_input_score * mask[-1]

        return score
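# --- Illustrative sketch (not part of the original listing) -------------------
# A tiny end-to-end check of the gold-path score computed above: with two tags and
# a length-2 unmasked sequence, the score should be start->t0 + emit(t0)
# + trans(t0, t1) + emit(t1) + t1->stop. All tensors below are hypothetical toy values.
import torch


def gold_path_score(logits, tags, transitions, start, end):
    """logits: (seq_len, num_tags); tags: (seq_len,) -- a single fully unmasked sequence."""
    score = start[tags[0]] + logits[0, tags[0]]
    for i in range(1, tags.size(0)):
        score = score + transitions[tags[i - 1], tags[i]] + logits[i, tags[i]]
    return score + end[tags[-1]]


if __name__ == "__main__":
    logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
    transitions = torch.tensor([[0.0, 0.5], [0.3, 0.0]])
    start = torch.tensor([0.2, 0.1])
    end = torch.tensor([0.0, 0.4])
    tags = torch.tensor([1, 0])
    expected = 0.1 + 0.9 + 0.3 + 0.8 + 0.0
    assert torch.isclose(gold_path_score(logits, tags, transitions, start, end),
                         torch.tensor(expected))
# ------------------------------------------------------------------------------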