def forward(self,
            left_contexts: TextFieldTensors,
            right_contexts: TextFieldTensors,
            targets: TextFieldTensors,
            target_sentiments: torch.LongTensor = None,
            metadata: torch.LongTensor = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    '''
    The contexts and targets are dictionaries because they are text
    fields, which can be represented in many different ways, e.g. just
    words, or words and characters. The dictionary therefore maps each
    representation to its tensor, e.g.
    {'words': words_tensor_ids, 'chars': char_tensor_ids}
    '''
    # `num_wrapping_dims=1` is required if the input has more than 3
    # dimensions, e.g. character input where it is
    # (batch size, number targets, token length, char length)
    targets_mask = util.get_text_field_mask(targets, num_wrapping_dims=1)
    targets_mask = (targets_mask.sum(dim=-1) >= 1).type(torch.int64)
    batch_size, number_targets = targets_mask.shape
    batch_size_num_targets = batch_size * number_targets

    # Embed the left contexts, temporarily flattening the
    # (batch, number targets) dimensions for ELMo-style embedders.
    temp_left_contexts = elmo_input_reshape(left_contexts, batch_size,
                                            number_targets,
                                            batch_size_num_targets)
    left_embedded_text = self.context_field_embedder(temp_left_contexts)
    left_embedded_text = elmo_input_reverse(left_embedded_text,
                                            left_contexts, batch_size,
                                            number_targets,
                                            batch_size_num_targets)
    left_embedded_text = self._time_variational_dropout(left_embedded_text)
    left_text_mask = util.get_text_field_mask(left_contexts,
                                              num_wrapping_dims=1)

    # Embed the right contexts in the same way.
    temp_right_contexts = elmo_input_reshape(right_contexts, batch_size,
                                             number_targets,
                                             batch_size_num_targets)
    right_embedded_text = self.context_field_embedder(temp_right_contexts)
    right_embedded_text = elmo_input_reverse(right_embedded_text,
                                             right_contexts, batch_size,
                                             number_targets,
                                             batch_size_num_targets)
    right_embedded_text = self._time_variational_dropout(right_embedded_text)
    right_text_mask = util.get_text_field_mask(right_contexts,
                                               num_wrapping_dims=1)

    if self.target_encoder:
        temp_target = elmo_input_reshape(targets, batch_size,
                                         number_targets,
                                         batch_size_num_targets)
        if self.target_field_embedder:
            embedded_target = self.target_field_embedder(temp_target)
        else:
            embedded_target = self.context_field_embedder(temp_target)
        embedded_target = elmo_input_reverse(embedded_target, targets,
                                             batch_size, number_targets,
                                             batch_size_num_targets)
        embedded_target = self._time_variational_dropout(embedded_target)
        target_text_mask = util.get_text_field_mask(targets,
                                                    num_wrapping_dims=1)

        target_encoded_text = self.target_encoder(embedded_target,
                                                  target_text_mask)
        target_encoded_text = self._naive_dropout(target_encoded_text)
        # The encoded target needs to be of dimension
        # (batch, number targets, words, dim); it is currently
        # (batch, number targets, dim)
        target_encoded_text = target_encoded_text.unsqueeze(2)
        # Repeat the target vector for each word in the left and right
        # contexts.
        left_num_padded = left_embedded_text.shape[2]
        right_num_padded = right_embedded_text.shape[2]
        left_targets = target_encoded_text.repeat((1, 1, left_num_padded, 1))
        right_targets = target_encoded_text.repeat((1, 1, right_num_padded, 1))
        # Add the target to each word in the left and right contexts
        left_embedded_text = torch.cat((left_embedded_text, left_targets), -1)
        right_embedded_text = torch.cat((right_embedded_text, right_targets),
                                        -1)

    left_encoded_text = self.left_text_encoder(left_embedded_text,
                                               left_text_mask)
    left_encoded_text = self._naive_dropout(left_encoded_text)

    right_encoded_text = self.right_text_encoder(right_embedded_text,
                                                 right_text_mask)
    right_encoded_text = self._naive_dropout(right_encoded_text)

    encoded_left_right = torch.cat([left_encoded_text, right_encoded_text],
                                   dim=-1)
    if self.inter_target_encoding is not None:
        encoded_left_right = self.inter_target_encoding(encoded_left_right,
                                                        targets_mask)
        encoded_left_right = self._variational_dropout(encoded_left_right)

    if self.feedforward:
        encoded_left_right = self.feedforward(encoded_left_right)
    logits = self.label_projection(encoded_left_right)
    masked_class_probabilities = util.masked_softmax(
        logits, targets_mask.unsqueeze(-1))
    output_dict = {"class_probabilities": masked_class_probabilities,
                   "targets_mask": targets_mask}
    # Convert it to bool tensor.
    targets_mask = targets_mask == 1

    if target_sentiments is not None:
        # gets the loss per target instance due to the average=`token`
        if self.loss_weights is not None:
            loss = util.sequence_cross_entropy_with_logits(
                logits, target_sentiments, targets_mask, average='token',
                alpha=self.loss_weights)
        else:
            loss = util.sequence_cross_entropy_with_logits(
                logits, target_sentiments, targets_mask, average='token')
        for metrics in [self.metrics, self.f1_metrics]:
            for metric in metrics.values():
                metric(logits, target_sentiments, targets_mask)
        output_dict["loss"] = loss

    if metadata is not None:
        words = []
        texts = []
        targets = []
        target_words = []
        for sample in metadata:
            words.append(sample['text words'])
            texts.append(sample['text'])
            targets.append(sample['targets'])
            target_words.append(sample['target words'])
        output_dict["words"] = words
        output_dict["text"] = texts
        output_dict["targets"] = targets
        output_dict["target words"] = target_words

    return output_dict
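# The sketch below is not part of the model: it is a minimal, self-contained
# illustration of the target-broadcast step used in the forward pass above,
# where the encoded target of shape (batch, number targets, dim) is repeated
# over the padded context length and concatenated onto every context word
# embedding. All sizes are made up for illustration and `torch` is assumed to
# be imported at module level, as in the surrounding code.
def _sketch_target_broadcast_and_concat() -> torch.Tensor:
    batch_size, num_targets, num_padded_words, word_dim, target_dim = 2, 3, 5, 4, 4
    embedded_context = torch.rand(batch_size, num_targets, num_padded_words, word_dim)
    encoded_target = torch.rand(batch_size, num_targets, target_dim)
    # (batch, targets, dim) -> (batch, targets, 1, dim) -> repeated over the words
    repeated_target = encoded_target.unsqueeze(2).repeat((1, 1, num_padded_words, 1))
    combined = torch.cat((embedded_context, repeated_target), -1)
    assert combined.shape == (batch_size, num_targets, num_padded_words,
                              word_dim + target_dim)
    return combined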
def forward(self,
            tokens: TextFieldTensors,
            targets: TextFieldTensors,
            target_sentiments: torch.LongTensor = None,
            target_sequences: Optional[torch.LongTensor] = None,
            metadata: torch.LongTensor = None,
            position_weights: Optional[torch.LongTensor] = None,
            position_embeddings: Optional[Dict[str, torch.LongTensor]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    '''
    The tokens and targets are dictionaries because they are text
    fields, which can be represented in many different ways, e.g. just
    words, or words and characters. The dictionary therefore maps each
    representation to its tensor, e.g.
    {'words': words_tensor_ids, 'chars': char_tensor_ids}
    '''
    # Get masks for the targets before they get manipulated.
    # `num_wrapping_dims=1` is required if the input has more than 3
    # dimensions, e.g. character input where it is
    # (batch size, number targets, token length, char length)
    targets_mask = util.get_text_field_mask(targets, num_wrapping_dims=1)
    label_mask = (targets_mask.sum(dim=-1) >= 1).type(torch.int64)
    batch_size, number_targets = label_mask.shape
    batch_size_num_targets = batch_size * number_targets

    # Embed and encode the text as a sequence
    embedded_context = self.context_field_embedder(tokens)
    embedded_context = self._variational_dropout(embedded_context)
    context_mask = util.get_text_field_mask(tokens)
    # Need to repeat the context so it is of shape
    # (Batch Size * Number Targets, Sequence Length, Dim); currently it is
    # (Batch Size, Sequence Length, Dim)
    batch_size, context_sequence_length, context_embed_dim = embedded_context.shape
    reshaped_embedding_context = embedded_context.unsqueeze(1).repeat(
        1, number_targets, 1, 1)
    reshaped_embedding_context = reshaped_embedding_context.view(
        batch_size_num_targets, context_sequence_length, context_embed_dim)

    # Embed and encode the targets as a sequence. If `_use_target_sequences`
    # is True the target embeddings come from the context itself.
    if self._use_target_sequences:
        _, _, target_sequence_length, target_index_length = target_sequences.shape
        target_index_len_err = ('The size of the context sequence '
                                f'{context_sequence_length} is not the same '
                                'as the target index sequence '
                                f'{target_index_length}. This is to get '
                                'the contextualized target through the context')
        assert context_sequence_length == target_index_length, target_index_len_err
        seq_targets_mask = target_sequences.view(batch_size_num_targets,
                                                 target_sequence_length,
                                                 target_index_length)
        reshaped_embedding_targets = torch.matmul(
            seq_targets_mask.type(torch.float32), reshaped_embedding_context)
    else:
        temp_targets = elmo_input_reshape(targets, batch_size, number_targets,
                                          batch_size_num_targets)
        if self.target_field_embedder:
            embedded_targets = self.target_field_embedder(temp_targets)
        else:
            embedded_targets = self.context_field_embedder(temp_targets)
        embedded_targets = elmo_input_reverse(embedded_targets, targets,
                                              batch_size, number_targets,
                                              batch_size_num_targets)
        # Size (batch size, num targets, target sequence length, embedding dim)
        embedded_targets = self._time_variational_dropout(embedded_targets)
        batch_size, number_targets, target_sequence_length, target_embed_dim = embedded_targets.shape
        reshaped_embedding_targets = embedded_targets.view(
            batch_size_num_targets, target_sequence_length, target_embed_dim)

    encoded_targets_mask = targets_mask.view(batch_size_num_targets,
                                             target_sequence_length)
    # Shape: (Batch Size * Number Targets, encoded dim)
    encoded_targets_seq = self.target_encoder(reshaped_embedding_targets,
                                              encoded_targets_mask)
    encoded_targets_seq = self._naive_dropout(encoded_targets_seq)

    repeated_context_mask = context_mask.unsqueeze(1).repeat(
        1, number_targets, 1)
    repeated_context_mask = repeated_context_mask.view(
        batch_size_num_targets, context_sequence_length)
    # Need to concat the target embeddings to the context words
    repeated_encoded_targets = encoded_targets_seq.unsqueeze(1).repeat(
        1, context_sequence_length, 1)
    if self._AE:
        reshaped_embedding_context = torch.cat(
            (reshaped_embedding_context, repeated_encoded_targets), -1)
    # Add position embeddings if required.
    reshaped_embedding_context = concat_position_embeddings(
        reshaped_embedding_context, position_embeddings,
        self.target_position_embedding)
    # Size (batch size * number targets, sequence length, embedding dim)
    reshaped_encoded_context_seq = self.context_encoder(
        reshaped_embedding_context, repeated_context_mask)
    reshaped_encoded_context_seq = self._variational_dropout(
        reshaped_encoded_context_seq)
    # Weighted position information encoded into the context sequence.
    if self.target_position_weight is not None:
        if position_weights is None:
            raise ValueError('This model requires `position_weights` to '
                             'better encode the target but none were given')
        position_output = self.target_position_weight(
            reshaped_encoded_context_seq, position_weights,
            repeated_context_mask)
        reshaped_encoded_context_seq, weighted_position_weights = position_output

    # Whether to concat the aspect embeddings on to the contextualised word
    # representations
    attention_encoded_context_seq = reshaped_encoded_context_seq
    if self._AttentionAE:
        attention_encoded_context_seq = torch.cat(
            (attention_encoded_context_seq, repeated_encoded_targets), -1)
    _, _, attention_encoded_dim = attention_encoded_context_seq.shape

    # Projection layer before the attention layer
    attention_encoded_context_seq = self.attention_project_layer(
        attention_encoded_context_seq)
    attention_encoded_context_seq = self._context_attention_activation_function(
        attention_encoded_context_seq)
    attention_encoded_context_seq = self._variational_dropout(
        attention_encoded_context_seq)

    # Attention over the context sequence
    attention_vector = self.attention_vector.unsqueeze(0).repeat(
        batch_size_num_targets, 1)
    attention_weights = self.context_attention_layer(
        attention_vector, attention_encoded_context_seq,
        repeated_context_mask)
    expanded_attention_weights = attention_weights.unsqueeze(-1)
    weighted_encoded_context_seq = reshaped_encoded_context_seq * expanded_attention_weights
    weighted_encoded_context_vec = weighted_encoded_context_seq.sum(dim=1)

    # Add the projected last hidden state of the context encoder to the
    # projected attention vector
    context_final_states = util.get_final_encoder_states(
        reshaped_encoded_context_seq, repeated_context_mask,
        bidirectional=self.context_encoder_bidirectional)
    context_final_states = self.final_hidden_state_projection_layer(
        context_final_states)
    weighted_encoded_context_vec = self.final_attention_projection_layer(
        weighted_encoded_context_vec)
    feature_vector = context_final_states + weighted_encoded_context_vec
    feature_vector = self._naive_dropout(feature_vector)

    # Reshape the vector into (Batch Size, Number Targets, feature dim)
    _, feature_dim = feature_vector.shape
    feature_target_seq = feature_vector.view(batch_size, number_targets,
                                             feature_dim)

    if self.inter_target_encoding is not None:
        feature_target_seq = self.inter_target_encoding(feature_target_seq,
                                                        label_mask)
        feature_target_seq = self._variational_dropout(feature_target_seq)

    if self.feedforward is not None:
        feature_target_seq = self.feedforward(feature_target_seq)
    logits = self.label_projection(feature_target_seq)
    masked_class_probabilities = util.masked_softmax(
        logits, label_mask.unsqueeze(-1))
    output_dict = {"class_probabilities": masked_class_probabilities,
                   "targets_mask": label_mask}
    # Convert it to bool tensor.
    label_mask = label_mask == 1

    if target_sentiments is not None:
        # gets the loss per target instance due to the average=`token`
        if self.loss_weights is not None:
            loss = util.sequence_cross_entropy_with_logits(
                logits, target_sentiments, label_mask, average='token',
                alpha=self.loss_weights)
        else:
            loss = util.sequence_cross_entropy_with_logits(
                logits, target_sentiments, label_mask, average='token')
        for metrics in [self.metrics, self.f1_metrics]:
            for metric in metrics.values():
                metric(logits, target_sentiments, label_mask)
        output_dict["loss"] = loss

    if metadata is not None:
        words = []
        texts = []
        targets = []
        target_words = []
        for sample in metadata:
            words.append(sample['text words'])
            texts.append(sample['text'])
            targets.append(sample['targets'])
            target_words.append(sample['target words'])
        output_dict["words"] = words
        output_dict["text"] = texts
        word_attention_weights = attention_weights.view(
            batch_size, number_targets, context_sequence_length)
        output_dict["word_attention"] = word_attention_weights
        output_dict["targets"] = targets
        output_dict["target words"] = target_words
        output_dict["context_mask"] = context_mask

    return output_dict
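# The sketch below is not part of the model: it is a minimal, self-contained
# illustration of the `_use_target_sequences` branch above, in which a one-hot
# target index matrix of shape (target length, context length) is matrix
# multiplied with the embedded context so every target token picks up its
# contextualised embedding directly from the context sequence. Sizes and
# positions are made up for illustration and `torch` is assumed to be imported
# at module level, as in the surrounding code.
def _sketch_contextualised_target_lookup() -> torch.Tensor:
    context_length, embed_dim, target_length = 6, 4, 2
    embedded_context = torch.rand(context_length, embed_dim)
    # Suppose the target occupies context positions 2 and 3.
    target_indexes = torch.zeros(target_length, context_length)
    target_indexes[0, 2] = 1.0
    target_indexes[1, 3] = 1.0
    # (target length, context length) x (context length, dim)
    # -> (target length, dim), i.e. rows 2 and 3 of the embedded context.
    embedded_target = torch.matmul(target_indexes, embedded_context)
    assert torch.allclose(embedded_target, embedded_context[2:4])
    return embedded_target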