Example #1
    def forward(self,
                left_contexts: TextFieldTensors,
                right_contexts: TextFieldTensors,
                targets: TextFieldTensors,
                target_sentiments: torch.LongTensor = None,
                metadata: torch.LongTensor = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        '''
        The text and targets are dictionaries because, as text fields, they
        can be represented in many different ways, e.g. just words, or words
        and characters. The dictionary keys name these representations, e.g.
        {'words': words_tensor_ids, 'chars': char_tensor_ids}
        '''
        # This is required if the input has more than 3 dims, e.g.
        # character input where it is
        # (batch size, number targets, token length, char length)
        targets_mask = util.get_text_field_mask(targets, num_wrapping_dims=1)
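        # A target counts as present if it has at least one non-padding token.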
        targets_mask = (targets_mask.sum(dim=-1) >= 1).type(torch.int64)
        batch_size, number_targets = targets_mask.shape
        batch_size_num_targets = batch_size * number_targets

        temp_left_contexts = elmo_input_reshape(left_contexts, batch_size,
                                                number_targets,
                                                batch_size_num_targets)
        left_embedded_text = self.context_field_embedder(temp_left_contexts)
        left_embedded_text = elmo_input_reverse(left_embedded_text,
                                                left_contexts, batch_size,
                                                number_targets,
                                                batch_size_num_targets)
        left_embedded_text = self._time_variational_dropout(left_embedded_text)
        left_text_mask = util.get_text_field_mask(left_contexts,
                                                  num_wrapping_dims=1)

        temp_right_contexts = elmo_input_reshape(right_contexts, batch_size,
                                                 number_targets,
                                                 batch_size_num_targets)
        right_embedded_text = self.context_field_embedder(temp_right_contexts)
        right_embedded_text = elmo_input_reverse(right_embedded_text,
                                                 right_contexts, batch_size,
                                                 number_targets,
                                                 batch_size_num_targets)
        right_embedded_text = self._time_variational_dropout(
            right_embedded_text)
        right_text_mask = util.get_text_field_mask(right_contexts,
                                                   num_wrapping_dims=1)
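        # When a target encoder is supplied, each target is encoded into a
        # single vector and concatenated onto every word embedding in its
        # left and right contexts before the context encoders run.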
        if self.target_encoder:
            temp_target = elmo_input_reshape(targets, batch_size,
                                             number_targets,
                                             batch_size_num_targets)
            if self.target_field_embedder:
                embedded_target = self.target_field_embedder(temp_target)
            else:
                embedded_target = self.context_field_embedder(temp_target)
            embedded_target = elmo_input_reverse(embedded_target, targets,
                                                 batch_size, number_targets,
                                                 batch_size_num_targets)
            embedded_target = self._time_variational_dropout(embedded_target)
            target_text_mask = util.get_text_field_mask(targets,
                                                        num_wrapping_dims=1)

            target_encoded_text = self.target_encoder(embedded_target,
                                                      target_text_mask)
            target_encoded_text = self._naive_dropout(target_encoded_text)
            # The encoded target is currently (batch, Number of Targets, dim);
            # add a words dimension so it becomes
            # (batch, Number of Targets, 1, dim)
            target_encoded_text = target_encoded_text.unsqueeze(2)

            # Repeat the encoded target once for each word in the left and
            # right contexts.
            left_num_padded = left_embedded_text.shape[2]
            right_num_padded = right_embedded_text.shape[2]

            left_targets = target_encoded_text.repeat(
                (1, 1, left_num_padded, 1))
            right_targets = target_encoded_text.repeat(
                (1, 1, right_num_padded, 1))
            # Add the target to each word in the left and right contexts
            left_embedded_text = torch.cat((left_embedded_text, left_targets),
                                           -1)
            right_embedded_text = torch.cat(
                (right_embedded_text, right_targets), -1)

        left_encoded_text = self.left_text_encoder(left_embedded_text,
                                                   left_text_mask)
        left_encoded_text = self._naive_dropout(left_encoded_text)

        right_encoded_text = self.right_text_encoder(right_embedded_text,
                                                     right_text_mask)
        right_encoded_text = self._naive_dropout(right_encoded_text)

        encoded_left_right = torch.cat([left_encoded_text, right_encoded_text],
                                       dim=-1)

        if self.inter_target_encoding is not None:
            encoded_left_right = self.inter_target_encoding(
                encoded_left_right, targets_mask)
            encoded_left_right = self._variational_dropout(encoded_left_right)

        if self.feedforward:
            encoded_left_right = self.feedforward(encoded_left_right)
        logits = self.label_projection(encoded_left_right)

        masked_class_probabilities = util.masked_softmax(
            logits, targets_mask.unsqueeze(-1))

        output_dict = {
            "class_probabilities": masked_class_probabilities,
            "targets_mask": targets_mask
        }
        # Convert the mask to a bool tensor.
        targets_mask = targets_mask == 1

        if target_sentiments is not None:
            # average='token' gives the loss per target instance
            if self.loss_weights is not None:
                loss = util.sequence_cross_entropy_with_logits(
                    logits,
                    target_sentiments,
                    targets_mask,
                    average='token',
                    alpha=self.loss_weights)
            else:
                loss = util.sequence_cross_entropy_with_logits(
                    logits, target_sentiments, targets_mask, average='token')
            for metrics in [self.metrics, self.f1_metrics]:
                for metric in metrics.values():
                    metric(logits, target_sentiments, targets_mask)
            output_dict["loss"] = loss

        if metadata is not None:
            words = []
            texts = []
            targets = []
            target_words = []
            for sample in metadata:
                words.append(sample['text words'])
                texts.append(sample['text'])
                targets.append(sample['targets'])
                target_words.append(sample['target words'])
            output_dict["words"] = words
            output_dict["text"] = texts
            output_dict["targets"] = targets
            output_dict["target words"] = target_words

        return output_dict
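
The core shape manipulation in this forward pass is the target concatenation:
each target is encoded into a single vector, repeated once per padded context
token, and concatenated onto the left and right word embeddings before the two
context encoders are run. Below is a minimal sketch of that step using plain
torch tensors; the sizes are illustrative and not taken from the model.

import torch

batch_size, number_targets, dim = 2, 3, 4
left_num_padded = 5

# One encoded vector per target: (batch size, number targets, dim)
target_encoded_text = torch.randn(batch_size, number_targets, dim)
# Embedded left context: (batch size, number targets, left tokens, dim)
left_embedded_text = torch.randn(batch_size, number_targets, left_num_padded, dim)

# Add a words dimension, repeat the target once per context token ...
left_targets = target_encoded_text.unsqueeze(2).repeat(1, 1, left_num_padded, 1)
# ... and concatenate it onto every word embedding.
left_embedded_text = torch.cat((left_embedded_text, left_targets), -1)
assert left_embedded_text.shape == (batch_size, number_targets, left_num_padded, dim * 2)
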
Example #2
    def forward(self,
                tokens: TextFieldTensors,
                targets: TextFieldTensors,
                target_sentiments: torch.LongTensor = None,
                target_sequences: Optional[torch.LongTensor] = None,
                metadata: torch.LongTensor = None,
                position_weights: Optional[torch.LongTensor] = None,
                position_embeddings: Optional[Dict[str,
                                                   torch.LongTensor]] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        '''
        The text and targets are dictionaries because, as text fields, they
        can be represented in many different ways, e.g. just words, or words
        and characters. The dictionary keys name these representations, e.g.
        {'words': words_tensor_ids, 'chars': char_tensor_ids}
        '''
        # Get masks for the targets before they get manipulated
        targets_mask = util.get_text_field_mask(targets, num_wrapping_dims=1)
        # This is required if the input is of shape greater than 3 dim e.g.
        # character input where it is
        # (batch size, number targets, token length, char length)
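        # A target counts as present if it has at least one non-padding token,
        # giving a (batch size, number targets) label mask.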
        label_mask = (targets_mask.sum(dim=-1) >= 1).type(torch.int64)
        batch_size, number_targets = label_mask.shape
        batch_size_num_targets = batch_size * number_targets

        # Embed and encode text as a sequence
        embedded_context = self.context_field_embedder(tokens)
        embedded_context = self._variational_dropout(embedded_context)
        context_mask = util.get_text_field_mask(tokens)
        # Need to repeat the context so it is of shape:
        # (Batch Size * Number Targets, Sequence Length, Dim) Currently:
        # (Batch Size, Sequence Length, Dim)
        batch_size, context_sequence_length, context_embed_dim = embedded_context.shape
        reshaped_embedding_context = embedded_context.unsqueeze(1).repeat(
            1, number_targets, 1, 1)
        reshaped_embedding_context = reshaped_embedding_context.view(
            batch_size_num_targets, context_sequence_length, context_embed_dim)
        # Embed and encode the target as a sequence. If
        # self._use_target_sequences is True the target embeddings are taken
        # from the embedded context.
        if self._use_target_sequences:
            _, _, target_sequence_length, target_index_length = target_sequences.shape
            target_index_len_err = (
                'The size of the context sequence '
                f'{context_sequence_length} is not the same'
                ' as the target index sequence '
                f'{target_index_length}. This is to get '
                'the contextualized target through the context')
            assert context_sequence_length == target_index_length, target_index_len_err
            seq_targets_mask = target_sequences.view(batch_size_num_targets,
                                                     target_sequence_length,
                                                     target_index_length)
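            # Multiplying the target index sequence with the embedded context
            # pulls the target's contextualised word embeddings out of the
            # context.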
            reshaped_embedding_targets = torch.matmul(
                seq_targets_mask.type(torch.float32),
                reshaped_embedding_context)
        else:
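            # elmo_input_reshape flattens the (batch size, number targets)
            # dims so embedders that cannot handle the extra target dimension
            # can be applied; elmo_input_reverse restores the original shape.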
            temp_targets = elmo_input_reshape(targets, batch_size,
                                              number_targets,
                                              batch_size_num_targets)
            if self.target_field_embedder:
                embedded_targets = self.target_field_embedder(temp_targets)
            else:
                embedded_targets = self.context_field_embedder(temp_targets)
                embedded_targets = elmo_input_reverse(embedded_targets,
                                                      targets, batch_size,
                                                      number_targets,
                                                      batch_size_num_targets)

            # Size (batch size, num targets, target sequence length, embedding dim)
            embedded_targets = self._time_variational_dropout(embedded_targets)
            batch_size, number_targets, target_sequence_length, target_embed_dim = embedded_targets.shape
            reshaped_embedding_targets = embedded_targets.view(
                batch_size_num_targets, target_sequence_length,
                target_embed_dim)

        encoded_targets_mask = targets_mask.view(batch_size_num_targets,
                                                 target_sequence_length)
        # Shape: (Batch Size * Number Targets, encoded dim)
        encoded_targets_seq = self.target_encoder(reshaped_embedding_targets,
                                                  encoded_targets_mask)
        encoded_targets_seq = self._naive_dropout(encoded_targets_seq)

        repeated_context_mask = context_mask.unsqueeze(1).repeat(
            1, number_targets, 1)
        repeated_context_mask = repeated_context_mask.view(
            batch_size_num_targets, context_sequence_length)
        # Need to concat the target embeddings to the context words
        repeated_encoded_targets = encoded_targets_seq.unsqueeze(1).repeat(
            1, context_sequence_length, 1)
        if self._AE:
            reshaped_embedding_context = torch.cat(
                (reshaped_embedding_context, repeated_encoded_targets), -1)
        # add position embeddings if required.
        reshaped_embedding_context = concat_position_embeddings(
            reshaped_embedding_context, position_embeddings,
            self.target_position_embedding)
        # Size (batch size * number targets, sequence length, embedding dim)
        reshaped_encoded_context_seq = self.context_encoder(
            reshaped_embedding_context, repeated_context_mask)
        reshaped_encoded_context_seq = self._variational_dropout(
            reshaped_encoded_context_seq)
        # Weighted position information encoded into the context sequence.
        if self.target_position_weight is not None:
            if position_weights is None:
                raise ValueError(
                    'This model requires `position_weights` to '
                    'better encode the target but none were given')
            position_output = self.target_position_weight(
                reshaped_encoded_context_seq, position_weights,
                repeated_context_mask)
            reshaped_encoded_context_seq, weighted_position_weights = position_output
        # Optionally concatenate the aspect embeddings onto the contextualised
        # word representations
        attention_encoded_context_seq = reshaped_encoded_context_seq
        if self._AttentionAE:
            attention_encoded_context_seq = torch.cat(
                (attention_encoded_context_seq, repeated_encoded_targets), -1)
        _, _, attention_encoded_dim = attention_encoded_context_seq.shape

        # Projection layer before the attention layer
        attention_encoded_context_seq = self.attention_project_layer(
            attention_encoded_context_seq)
        attention_encoded_context_seq = self._context_attention_activation_function(
            attention_encoded_context_seq)
        attention_encoded_context_seq = self._variational_dropout(
            attention_encoded_context_seq)

        # Attention over the context sequence
        attention_vector = self.attention_vector.unsqueeze(0).repeat(
            batch_size_num_targets, 1)
        attention_weights = self.context_attention_layer(
            attention_vector, attention_encoded_context_seq,
            repeated_context_mask)
        expanded_attention_weights = attention_weights.unsqueeze(-1)
        weighted_encoded_context_seq = reshaped_encoded_context_seq * expanded_attention_weights
        weighted_encoded_context_vec = weighted_encoded_context_seq.sum(dim=1)

        # Combine the final hidden state of the context encoder with the
        # attention weighted context vector
        context_final_states = util.get_final_encoder_states(
            reshaped_encoded_context_seq,
            repeated_context_mask,
            bidirectional=self.context_encoder_bidirectional)
        context_final_states = self.final_hidden_state_projection_layer(
            context_final_states)
        weighted_encoded_context_vec = self.final_attention_projection_layer(
            weighted_encoded_context_vec)
        feature_vector = context_final_states + weighted_encoded_context_vec
        feature_vector = self._naive_dropout(feature_vector)
        # Reshape the vector into (Batch Size, Number Targets, feature dim)
        _, feature_dim = feature_vector.shape
        feature_target_seq = feature_vector.view(batch_size, number_targets,
                                                 feature_dim)

        if self.inter_target_encoding is not None:
            feature_target_seq = self.inter_target_encoding(
                feature_target_seq, label_mask)
            feature_target_seq = self._variational_dropout(feature_target_seq)

        if self.feedforward is not None:
            feature_target_seq = self.feedforward(feature_target_seq)

        logits = self.label_projection(feature_target_seq)
        masked_class_probabilities = util.masked_softmax(
            logits, label_mask.unsqueeze(-1))
        output_dict = {
            "class_probabilities": masked_class_probabilities,
            "targets_mask": label_mask
        }
        # Convert the mask to a bool tensor.
        label_mask = label_mask == 1

        if target_sentiments is not None:
            # average='token' gives the loss per target instance
            if self.loss_weights is not None:
                loss = util.sequence_cross_entropy_with_logits(
                    logits,
                    target_sentiments,
                    label_mask,
                    average='token',
                    alpha=self.loss_weights)
            else:
                loss = util.sequence_cross_entropy_with_logits(
                    logits, target_sentiments, label_mask, average='token')
            for metrics in [self.metrics, self.f1_metrics]:
                for metric in metrics.values():
                    metric(logits, target_sentiments, label_mask)
            output_dict["loss"] = loss

        if metadata is not None:
            words = []
            texts = []
            targets = []
            target_words = []
            for sample in metadata:
                words.append(sample['text words'])
                texts.append(sample['text'])
                targets.append(sample['targets'])
                target_words.append(sample['target words'])

            output_dict["words"] = words
            output_dict["text"] = texts
            word_attention_weights = attention_weights.view(
                batch_size, number_targets, context_sequence_length)
            output_dict["word_attention"] = word_attention_weights
            output_dict["targets"] = targets
            output_dict["target words"] = target_words
            output_dict["context_mask"] = context_mask

        return output_dict
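
The attention block in the middle of this forward pass reduces the encoded
context sequence to one vector per target: a learned query vector scores every
context token, the masked weights rescale the token representations, and the
weighted sum is later combined with the final encoder state. Below is a minimal
sketch of the weighting step that uses a plain dot-product score in place of
AllenNLP's attention layer (the model's self.context_attention_layer); names
and sizes are illustrative.

import torch

batch_times_targets, seq_len, dim = 6, 7, 4

encoded_context = torch.randn(batch_times_targets, seq_len, dim)
context_mask = torch.ones(batch_times_targets, seq_len, dtype=torch.bool)
attention_vector = torch.randn(dim)

# Score every token against the query vector, then softmax with padding
# positions masked out so they receive zero weight.
scores = encoded_context @ attention_vector  # (batch * targets, seq len)
scores = scores.masked_fill(~context_mask, float('-inf'))
attention_weights = torch.softmax(scores, dim=-1)

# Weight each token representation and sum over the sequence, mirroring
# weighted_encoded_context_seq.sum(dim=1) above.
weighted_encoded_context_vec = (encoded_context * attention_weights.unsqueeze(-1)).sum(dim=1)
assert weighted_encoded_context_vec.shape == (batch_times_targets, dim)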