class TransformerHead(LoggedModule): def __init__(self, config, v_dim, l_dim, loc_dim, backbone): super(TransformerHead, self).__init__() self.config = config.MODEL.MMSS_HEAD.TRANSFORMER self.v_dim = v_dim self.l_dim = l_dim self.loc_dim = loc_dim self.backbone = backbone self.mvm_loss = self.config.MVM_LOSS self.mmm_loss = self.config.MMM_LOSS self.num_negative = self.config.MVM_LOSS_NUM_NEGATIVE self.bert_config = BertConfig(**self.config.BERT_CONFIG) self.v2l_projection = nn.Linear(self.v_dim, self.l_dim) self.visual_emb = VisualEmbedding(self.bert_config, self.l_dim, self.loc_dim) self.encoder = BertEncoder(self.bert_config) self.pooler = BertPooler(self.bert_config) self.heads = MMPreTrainingHeads(self.bert_config, self.v_dim) self.encoder.apply(self._init_weights) self.pooler.apply(self._init_weights) self.heads.apply(self._init_weights) self._tie_weights() self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1) if self.mvm_loss == 'reconstruction_error': self.vis_criterion = nn.MSELoss(reduction="none") elif self.mvm_loss == 'contrastive_cross_entropy': self.vis_criterion = nn.CrossEntropyLoss() elif self.mvm_loss == '': self.vis_criterion = None for p in self.heads.imagePredictions.parameters(): p.requires_grad = False else: raise NotImplementedError if self.mmm_loss == '': for p in self.pooler.parameters(): p.requires_grad = False for p in self.heads.bi_seq_relationship.parameters(): p.requires_grad = False def _tie_weights(self): assert(self.heads.predictions.decoder.weight.shape[0] == self.backbone.embeddings.shape[0]) assert(self.heads.predictions.decoder.weight.shape[1] == self.backbone.embeddings.shape[1]) self.heads.predictions.decoder.weight = self.backbone.embeddings def _init_weights(self, module): """ Initialize the weights """ if isinstance(module, (nn.Linear, nn.Embedding)): module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range) elif isinstance(module, BertLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() def forward(self, input_image, input_caption): ''' Mapping between my terminology and the vilbert codebase: input_ids => Not needed. Instead I have caption_emb which is the output of BERT image_feat => region_features image_loc => region_loc token_type_ids => Not needed. All zero anyway attention_mask => caption_mask, image_attention_mask => region mask, masked_lm_labels => combination of target_caption_ids, mlm_mask, image_label => mvm_mask, image_target => target_region_features, next_sentence_label => Not needed for now (image_caption_match_label). 
''' caption_emb = input_caption['encoded_tokens'] caption_mask = input_caption['attention_mask'] mlm_mask = input_caption['mlm_mask'] target_caption_ids = input_caption['target_ids'] region_features = input_image['region_features'] region_mask = input_image['region_mask'] region_loc = input_image['region_loc'] mvm_mask = input_image['mvm_mask'] target_region_features = input_image['target_region_features'] target_caption_ids = torch.where( mlm_mask > 0, target_caption_ids, torch.ones_like(target_caption_ids) * (-1) ) caption_mask = caption_mask.to(torch.float32) region_mask = region_mask.to(torch.float32) mlm_mask = mlm_mask.to(torch.float32) mvm_mask = mvm_mask.to(torch.float32) num_words = caption_mask.sum(dim=1) _, max_num_words = caption_mask.shape batch_size, max_num_regions, _ = region_features.shape image_emb = self.v2l_projection(region_features) image_emb = self.visual_emb(image_emb, region_loc) if self.mmm_loss == 'cross_entropy': image_emb = image_emb[None, :, :, :].repeat(batch_size, 1, 1, 1).reshape( batch_size**2, max_num_regions, self.l_dim) caption_emb = caption_emb[:, None, :, :].repeat(1, batch_size, 1, 1).reshape( batch_size**2, max_num_words, self.l_dim) region_mask = region_mask[None, :, :].repeat(batch_size, 1, 1).reshape( batch_size**2, max_num_regions) caption_mask = caption_mask[:, None, :].repeat(1, batch_size, 1).reshape( batch_size**2, max_num_words) embedded_tokens = torch.cat([caption_emb, image_emb], dim=1) attention_mask = torch.cat([caption_mask, region_mask], dim=1) sequence_output, = self.encoder( embedded_tokens, attention_mask[:, None, None, :], head_mask=[None] * self.bert_config.num_hidden_layers, output_attentions=False, output_hidden_states=False, ) pooled_output = self.pooler(sequence_output) sequence_output_t, sequence_output_v = torch.split( sequence_output, [max_num_words, max_num_regions], dim=1 ) prediction_scores_t, prediction_scores_v, seq_relationship_score = self.heads( sequence_output_t, sequence_output_v, pooled_output ) # prediction_scores_v = prediction_scores_v[:, 1:] if self.mmm_loss == 'cross_entropy': prediction_scores_t = torch.diagonal(prediction_scores_t.reshape( batch_size, batch_size, max_num_words, self.bert_config.vocab_size), dim1=0, dim2=1).permute(2, 0, 1) prediction_scores_v = torch.diagonal(prediction_scores_v.reshape( batch_size, batch_size, max_num_regions, self.v_dim), dim1=0, dim2=1).permute(2, 0, 1) masked_lm_loss = self.loss_fct( prediction_scores_t.reshape(-1, self.bert_config.vocab_size), target_caption_ids.reshape(-1), ) if self.mmm_loss == 'binary': raise NotImplementedError next_sentence_loss = self.loss_fct( seq_relationship_score.view(-1, 2), image_caption_match_label.view(-1) ) elif self.mmm_loss == 'cross_entropy': global_dist = seq_relationship_score[:, 0] pw_cost = global_dist.reshape(batch_size, batch_size) pw_logits_c_cap = torch.log_softmax(- pw_cost, dim=0) pw_logits_c_img = torch.log_softmax(- pw_cost, dim=1) next_sentence_loss_c_cap = torch.diag(- pw_logits_c_cap).mean() next_sentence_loss_c_img = torch.diag(- pw_logits_c_img).mean() next_sentence_loss = next_sentence_loss_c_cap + next_sentence_loss_c_img elif self.mmm_loss == '': next_sentence_loss = torch.tensor(0.0).cuda() else: raise NotImplementedError if self.mvm_loss == 'reconstruction_error': raise NotImplementedError img_loss = self.vis_criterion(prediction_scores_v, target_region_features) masked_img_loss = torch.sum( img_loss * (mvm_mask == 1).unsqueeze(2).float() ) / max( torch.sum((mvm_mask == 1).unsqueeze(2).expand_as(img_loss)), 1 ) 
        elif self.mvm_loss == 'contrastive_cross_entropy':
            raise NotImplementedError
            # generate negative sampled index.
            num_negative = self.num_negative
            num_across_batch = int(self.num_negative * 0.7)
            num_inside_batch = int(self.num_negative * 0.3)
            # random negatives across batches.
            row_across_index = target_caption_ids.new(
                batch_size, max_num_regions, num_across_batch
            ).random_(0, batch_size - 1)
            col_across_index = target_caption_ids.new(
                batch_size, max_num_regions, num_across_batch
            ).random_(0, max_num_regions)
            for i in range(batch_size - 1):
                row_across_index[i][row_across_index[i] == i] = batch_size - 1
            final_across_index = row_across_index * max_num_regions + col_across_index
            # random negatives inside batches.
            row_inside_index = target_caption_ids.new(
                batch_size, max_num_regions, num_inside_batch
            ).zero_()
            col_inside_index = target_caption_ids.new(
                batch_size, max_num_regions, num_inside_batch
            ).random_(0, max_num_regions - 1)
            for i in range(batch_size):
                row_inside_index[i] = i
            for i in range(max_num_regions - 1):
                col_inside_index[:, i, :][col_inside_index[:, i, :] == i] = (
                    max_num_regions - 1
                )
            final_inside_index = row_inside_index * max_num_regions + col_inside_index
            final_index = torch.cat((final_across_index, final_inside_index), dim=2)
            # First select the positions where the loss needs to be computed.
            predict_v = prediction_scores_v[mvm_mask == 1]
            neg_index_v = final_index[mvm_mask == 1]
            flat_image_target = target_region_features.view(batch_size * max_num_regions, -1)
            # We also need to append the target feature at the beginning.
            negative_v = flat_image_target[neg_index_v]
            positive_v = target_region_features[mvm_mask == 1]
            sample_v = torch.cat((positive_v.unsqueeze(1), negative_v), dim=1)
            # calculate the loss.
            score = torch.bmm(sample_v, predict_v.unsqueeze(2)).squeeze(2)
            masked_img_loss = self.vis_criterion(
                score, target_caption_ids.new(score.size(0)).zero_()
            )
        elif self.mvm_loss == '':
            masked_img_loss = torch.tensor(0.0).cuda()
        else:
            raise NotImplementedError
        # masked_img_loss = torch.sum(img_loss) / (img_loss.shape[0] * img_loss.shape[1])

        losses = {
            'Masked Language Modeling Loss': masked_lm_loss,
            'Masked Visual Modeling Loss': masked_img_loss,
            'Image Caption Matching Loss': next_sentence_loss,
        }

        acc_num = (prediction_scores_t.argmax(dim=-1) == target_caption_ids).to(torch.float32).sum()
        acc_denom = (target_caption_ids >= 0).to(torch.float32).sum()
        acc = torch.where(acc_denom > 0, acc_num / acc_denom, acc_denom)
        other_info = {
            'Masked Language Modeling Accuracy': acc,
        }
        if self.mmm_loss == 'cross_entropy':
            other_info['Batch Accuracy (Choose Caption)'] = torch.mean(
                (pw_cost.argmin(dim=0) == torch.arange(batch_size).to('cuda')).to(torch.float32))
            other_info['Batch Accuracy (Choose Image)'] = torch.mean(
                (pw_cost.argmin(dim=1) == torch.arange(batch_size).to('cuda')).to(torch.float32))

        self.log_dict(losses)
        self.log_dict(other_info)
        return other_info, losses
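# The 'cross_entropy' branch of the image-caption matching loss above reduces to a
# symmetric cross-entropy over a [batch_size, batch_size] pairwise cost matrix whose
# diagonal holds the matching pairs. A minimal, self-contained sketch of that loss is
# given below; the function name and the random input are illustrative only, not part
# of the original code.
import torch


def symmetric_matching_loss(pw_cost: torch.Tensor) -> torch.Tensor:
    """pw_cost: [batch_size, batch_size], rows = captions, columns = images, lower = better match."""
    # For each image (column), a distribution over candidate captions (rows).
    log_p_choose_caption = torch.log_softmax(-pw_cost, dim=0)
    # For each caption (row), a distribution over candidate images (columns).
    log_p_choose_image = torch.log_softmax(-pw_cost, dim=1)
    # Matching pairs sit on the diagonal; take the mean negative log-likelihood both ways.
    loss_choose_caption = -torch.diag(log_p_choose_caption).mean()
    loss_choose_image = -torch.diag(log_p_choose_image).mean()
    return loss_choose_caption + loss_choose_image


# Example (illustrative): symmetric_matching_loss(torch.randn(8, 8)) -> scalar tensor.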
class BiaffineDependencyS2SQueryParser(torch.nn.Module): """ This dependency parser follows the model of [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)] (https://arxiv.org/abs/1611.01734) . But We use token-to-token MRC to extract parent and labels Args: config: S2SConfig that defines model dim and structure bert_dir: pretrained bert directory """ # def __init__(self, config: Union[BertMrcS2SDependencyConfig, RobertaMrcS2SDependencyConfig]): def __init__(self, config: S2SConfig, bert_dir: str = ""): # super().__init__(config) super().__init__() self.config = config num_dep_labels = len(config.dep_tags) num_pos_labels = len(config.pos_tags) hidden_size = config.additional_layer_dim if config.pos_dim > 0: self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim) nn.init.xavier_uniform_(self.pos_embedding.weight) if config.additional_layer_type != "lstm" and config.pos_dim + config.bert_config.hidden_size != hidden_size: self.fuse_layer = nn.Linear( config.pos_dim + config.bert_config.hidden_size, hidden_size) nn.init.xavier_uniform_(self.fuse_layer.weight) self.fuse_layer.bias.data.zero_() else: self.fuse_layer = None else: self.pos_embedding = None # if isinstance(config, BertMrcS2SDependencyConfig): # self.bert = BertModel(config) # self.arch = "bert" # else: # self.roberta = RobertaModel(config) # self.arch = "roberta" self.bert = AutoModel.from_pretrained( pretrained_model_name_or_path=bert_dir, config=config.bert_config) if config.additional_layer > 0: if config.additional_layer_type == "transformer": new_config = deepcopy(config.bert_config) new_config.hidden_size = hidden_size new_config.num_hidden_layers = config.additional_layer new_config.hidden_dropout_prob = new_config.attention_probs_dropout_prob = config.mrc_dropout # new_config.attention_probs_dropout_prob = config.biaf_dropout # todo add to hparams and tune self.additional_encoder = BertEncoder(new_config) self.additional_encoder.apply(self._init_bert_weights) else: assert hidden_size % 2 == 0, "Bi-LSTM need an even hidden_size" self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder( input_size=config.pos_dim + config.bert_config.hidden_size, hidden_size=hidden_size // 2, num_layers=config.additional_layer, recurrent_dropout_probability=config.mrc_dropout, use_highway=True) else: self.additional_encoder = None # todo use MLP self.parent_feedforward = nn.Linear(hidden_size, 1) self.parent_start_feedforward = nn.Linear(hidden_size, 1) self.parent_end_feedforward = nn.Linear(hidden_size, 1) self.parent_tag_feedforward = nn.Linear(hidden_size, num_dep_labels) if config.predict_child: self.child_feedforward = nn.Linear(hidden_size, 1) self.child_start_feedforward = nn.Linear(hidden_size, 1) self.child_end_feedforward = nn.Linear(hidden_size, 1) # self._dropout = nn.Dropout(config.mrc_dropout) self._dropout = InputVariationalDropout(config.mrc_dropout) # init linear children for layer in self.children(): if isinstance(layer, nn.Linear): nn.init.xavier_uniform_(layer.weight) if layer.bias is not None: layer.bias.data.zero_() def _init_bert_weights(self, module): """ Initialize the weights. 
copy from transformers.BertPreTrainedModel""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_( mean=0.0, std=self.config.bert_config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() @overrides def forward( self, # type: ignore token_ids: torch.LongTensor, type_ids: torch.LongTensor, offsets: torch.LongTensor, wordpiece_mask: torch.BoolTensor, pos_tags: torch.LongTensor, word_mask: torch.BoolTensor, parent_mask: torch.BoolTensor, parent_start_mask: torch.BoolTensor, parent_end_mask: torch.BoolTensor, child_mask: torch.BoolTensor = None, parent_idxs: torch.LongTensor = None, parent_tags: torch.LongTensor = None, parent_starts: torch.BoolTensor = None, parent_ends: torch.BoolTensor = None, child_idxs: torch.BoolTensor = None, child_starts: torch.BoolTensor = None, child_ends: torch.BoolTensor = None, ): """ todo implement docstring Args: token_ids: [batch_size, num_word_pieces] type_ids: [batch_size, num_word_pieces] offsets: [batch_size, num_words, 2] wordpiece_mask: [batch_size, num_word_pieces] pos_tags: [batch_size, num_words] word_mask: [batch_size, num_words] parent_mask: [batch_size, num_words] parent_start_mask: [batch_size, num_words] parent_end_mask: [batch_size, num_words] child_mask: [batch_size, num_words] parent_idxs: [batch_size] parent_tags: [batch_size] parent_starts: [batch_size] parent_ends: [batch_size] child_idxs: [batch_size, num_words] child_starts: [batch_size, num_words] child_ends: [batch_size, num_words] Returns: parent_probs: [batch_size, num_words] parent_tag_probs: [batch_size, num_words, num_tags] parent_start_probs: [batch_size, num_words] parent_end_probs: [batch_size, num_words] child_probs: [batch_size, num_words] child_start_probs: [batch_size, num_words] child_end_probs: [batch_size, num_words] arc_loss (if parent_idx is not None) tag_loss (if parent_idxs and parent_tags are not None) start_loss (if parent_starts is not None) end_loss (if parent_ends is not None) child_loss (if child_idxs is not None) child_start_loss (if child_starts is not None) child_end_loss (if child_ends is not None) """ cls_embedding, embedded_text_input = self.get_word_embedding( token_ids=token_ids, offsets=offsets, wordpiece_mask=wordpiece_mask, type_ids=type_ids, ) if self.pos_embedding is not None: embedded_pos_tags = self.pos_embedding(pos_tags) embedded_text_input = torch.cat( [embedded_text_input, embedded_pos_tags], -1) if self.fuse_layer is not None: embedded_text_input = self.fuse_layer(embedded_text_input) # todo compare normal dropout with InputVariationalDropout embedded_text_input = self._dropout(embedded_text_input) if self.additional_encoder is not None: if self.config.additional_layer_type == "transformer": # bert = self.bert if self.arch == "bert" else self.roberta extended_attention_mask = self.bert.get_extended_attention_mask( word_mask, word_mask.size(), word_mask.device) encoded_text = self.additional_encoder( hidden_states=embedded_text_input, attention_mask=extended_attention_mask)[0] else: encoded_text = self.additional_encoder( inputs=embedded_text_input, mask=word_mask) else: encoded_text = embedded_text_input batch_size, seq_len, encoding_dim = encoded_text.size() # shape (batch_size, sequence_length, tag_classes) parent_tag_scores = 
self.parent_tag_feedforward(encoded_text) # shape (batch_size, sequence_length) parent_scores = self.parent_feedforward(encoded_text).squeeze(-1) parent_start_scores = self.parent_start_feedforward( encoded_text).squeeze(-1) parent_end_scores = self.parent_end_feedforward(encoded_text).squeeze( -1) # mask out impossible positions minus_inf = -1e8 parent_mask = torch.logical_and(parent_mask, word_mask) parent_scores = parent_scores + (~parent_mask).float() * minus_inf parent_start_mask = torch.logical_and(parent_start_mask, word_mask) parent_start_scores = parent_start_scores + ( ~parent_start_mask).float() * minus_inf parent_end_mask = torch.logical_and(parent_end_mask, word_mask) parent_end_scores = parent_end_scores + ( ~parent_end_mask).float() * minus_inf parent_probs = F.softmax(parent_scores, dim=-1) parent_start_probs = F.softmax(parent_start_scores, dim=-1) parent_end_probs = F.softmax(parent_end_scores, dim=-1) parent_tag_probs = F.softmax(parent_tag_scores, dim=-1) output = (parent_probs, parent_tag_probs, parent_start_probs, parent_end_probs) if self.config.predict_child: child_scores = self.child_feedforward(encoded_text).squeeze(-1) child_start_scores = self.child_start_feedforward( encoded_text).squeeze(-1) child_end_scores = self.child_end_feedforward( encoded_text).squeeze(-1) # todo add child mask - child should be inside the origin span if child_mask is None: child_mask = torch.ones_like(word_mask) else: child_mask = torch.logical_and(child_mask, word_mask) child_scores = child_scores + (~child_mask).float() * minus_inf child_start_scores = child_start_scores + ( ~child_mask).float() * minus_inf child_end_scores = child_end_scores + ( ~child_mask).float() * minus_inf child_probs = torch.sigmoid(child_scores) child_start_probs = torch.sigmoid(child_start_scores) child_end_probs = torch.sigmoid(child_end_scores) output = output + (child_probs, child_start_probs, child_end_probs) # add losses batch_range_vector = get_range_vector( batch_size, get_device_of(encoded_text)) # [bsz] if parent_idxs is not None: # [bsz, seq_len] parent_logits = F.log_softmax(parent_scores, dim=-1) parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs] parent_arc_nll = parent_arc_nll.mean() output = output + (parent_arc_nll, ) if parent_tags is not None: parent_tag_nll = F.cross_entropy( parent_tag_scores[batch_range_vector, parent_idxs], parent_tags) output = output + (parent_tag_nll, ) if parent_starts is not None: # [bsz, seq_len] parent_start_logits = F.log_softmax(parent_start_scores, dim=-1) parent_start_nll = -parent_start_logits[batch_range_vector, parent_starts].mean() output = output + (parent_start_nll, ) if parent_ends is not None: # [bsz, seq_len] parent_end_logits = F.log_softmax(parent_end_scores, dim=-1) parent_end_nll = -parent_end_logits[batch_range_vector, parent_ends].mean() output = output + (parent_end_nll, ) if self.config.predict_child: if child_idxs is not None: child_loss = F.binary_cross_entropy_with_logits( child_scores, child_idxs.float(), reduction="none") child_loss = (child_loss * child_mask).sum() / (child_mask.sum() + 1e-8) output = output + (child_loss, ) if child_starts is not None: child_start_loss = F.binary_cross_entropy_with_logits( child_start_scores, child_starts.float(), reduction="none") child_start_loss = (child_start_loss * child_mask).sum() / ( child_mask.sum() + 1e-8) output = output + (child_start_loss, ) if child_ends is not None: child_end_loss = F.binary_cross_entropy_with_logits( child_end_scores, child_ends.float(), reduction="none") 
                child_end_loss = (child_end_loss * child_mask).sum() / (
                    child_mask.sum() + 1e-8)
                output = output + (child_end_loss, )

        return output

    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # type: ignore
        """get [CLS] embedding and word-level embedding"""
        # Shape: [batch_size, num_wordpieces, embedding_size].
        # embed_model = self.bert if self.arch == "bert" else self.roberta
        # embeddings = embed_model(token_ids, token_type_ids=type_ids, attention_mask=wordpiece_mask)[0]
        embeddings = self.bert(token_ids,
                               token_type_ids=type_ids,
                               attention_mask=wordpiece_mask)[0]
        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings
        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)
        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0
        return embeddings[:, 0, :], orig_embeddings
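# get_word_embedding above relies on allennlp_util.batched_span_select to pool wordpiece
# vectors back to word level. A rough plain-PyTorch sketch of the same pooling, assuming
# inclusive [start, end] wordpiece offsets per word; the helper name and the explicit
# loops are illustrative (the original is fully vectorized).
import torch


def pool_wordpieces_to_words(embeddings: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """embeddings: [bsz, num_wordpieces, dim]; offsets: [bsz, num_words, 2] -> [bsz, num_words, dim]."""
    bsz, num_words, _ = offsets.shape
    dim = embeddings.size(-1)
    pooled = embeddings.new_zeros(bsz, num_words, dim)
    for b in range(bsz):
        for w in range(num_words):
            start, end = offsets[b, w, 0].item(), offsets[b, w, 1].item()
            if end >= start:  # padded words keep their zero vector
                pooled[b, w] = embeddings[b, start:end + 1].mean(dim=0)
    return pooled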
class BiaffineDependencyT2TParser(nn.Module): """ This dependency parser follows the model of [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)] (https://arxiv.org/abs/1611.01734) . But We use token-to-token MRC to extract parent and labels """ def __init__(self, bert_dir, config): super().__init__() self.config = config num_dep_labels = len(config.dep_tags) num_pos_labels = len(config.pos_tags) hidden_size = config.additional_layer_dim if config.pos_dim > 0: self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim) nn.init.xavier_uniform_(self.pos_embedding.weight) if config.additional_layer_type != "lstm" and config.pos_dim + config.hidden_size != hidden_size: self.fuse_layer = nn.Linear( config.pos_dim + config.hidden_size, hidden_size) nn.init.xavier_uniform_(self.fuse_layer.weight) self.fuse_layer.bias.data.zero_() else: self.fuse_layer = None else: self.pos_embedding = None if isinstance(config, BertMrcT2TDependencyConfig): self.bert = BertModel.from_pretrained(bert_dir, config=self.config) elif isinstance(config, RobertaMrcT2TDependencyConfig): self.bert = RobertaModel.from_pretrained(bert_dir, config=self.config) if config.additional_layer > 0: if config.additional_layer_type == "transformer": new_config = deepcopy(config) new_config.hidden_size = hidden_size new_config.num_hidden_layers = config.additional_layer new_config.hidden_dropout_prob = new_config.attention_probs_dropout_prob = config.mrc_dropout # new_config.attention_probs_dropout_prob = config.biaf_dropout # todo add to hparams and tune self.additional_encoder = BertEncoder(new_config) self.additional_encoder.apply(self._init_bert_weights) else: assert hidden_size % 2 == 0, "Bi-LSTM need an even hidden_size" self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder( input_size=config.pos_dim + config.hidden_size, hidden_size=hidden_size // 2, num_layers=config.additional_layer, recurrent_dropout_probability=config.mrc_dropout, use_highway=True) else: self.additional_encoder = None self.parent_feedforward = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), nn.GELU(), InputVariationalDropout(config.mrc_dropout), nn.Linear(hidden_size // 2, 1), ) self.parent_tag_feedforward = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), nn.GELU(), InputVariationalDropout(config.mrc_dropout), nn.Linear(hidden_size // 2, num_dep_labels), ) self.child_feedforward = deepcopy(self.parent_feedforward) self.child_tag_feedforward = deepcopy(self.parent_tag_feedforward) # self.mrc_dropout = nn.Dropout(config.mrc_dropout) self._dropout = InputVariationalDropout(config.mrc_dropout) # todo renit feedforward? def _init_bert_weights(self, module): """ Initialize the weights. 
copy from transformers.BertPreTrainedModel""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() @overrides def forward( self, # type: ignore token_ids: torch.LongTensor, type_ids: torch.LongTensor, offsets: torch.LongTensor, wordpiece_mask: torch.BoolTensor, span_idx: torch.LongTensor, span_tag: torch.LongTensor, child_arcs: torch.LongTensor, child_tags: torch.LongTensor, pos_tags: torch.LongTensor, word_mask: torch.BoolTensor, mrc_mask: torch.BoolTensor, ): """ todo implement docstring Args: token_ids: [batch_size, num_word_pieces] type_ids: [batch_size, num_word_pieces] offsets: [batch_size, num_words, 2] wordpiece_mask: [batch_size, num_word_pieces] span_idx: [batch_size, 2] span_tag: [batch_size, 1] child_arcs: [batch_size, num_words] child_tags: [batch_size, num_words] pos_tags: [batch_size, num_words] word_mask: [batch_size, num_words] mrc_mask: [batch_size, num_words] Returns: parent_probs: [batch_size, num_word] parent_tag_probs: [batch_size, num_words] arc_nll: [1] tag_nll: [1] """ embedded_text_input = self.get_word_embedding( token_ids=token_ids, offsets=offsets, wordpiece_mask=wordpiece_mask, type_ids=type_ids, ) if self.pos_embedding is not None: embedded_pos_tags = self.pos_embedding(pos_tags) embedded_text_input = torch.cat( [embedded_text_input, embedded_pos_tags], -1) if self.fuse_layer is not None: embedded_text_input = self.fuse_layer(embedded_text_input) # todo compare normal dropout with InputVariationalDropout embedded_text_input = self._dropout(embedded_text_input) if self.additional_encoder is not None: if self.config.additional_layer_type == "transformer": extended_attention_mask = self.bert.get_extended_attention_mask( word_mask, word_mask.size(), word_mask.device) encoded_text = self.additional_encoder( hidden_states=embedded_text_input, attention_mask=extended_attention_mask)[0] else: encoded_text = self.additional_encoder( inputs=embedded_text_input, mask=word_mask) else: encoded_text = embedded_text_input batch_size, seq_len, encoding_dim = encoded_text.size() # shape (batch_size, sequence_length, tag_classes) parent_tag_scores = self.parent_tag_feedforward(encoded_text) # shape (batch_size, sequence_length) parent_scores = self.parent_feedforward(encoded_text).squeeze(-1) # [bsz, seq_len, tag_classes] child_tag_scores = self.child_tag_feedforward(encoded_text) # [bsz, seq_len] child_scores = self.child_feedforward(encoded_text).squeeze(-1) # todo support cases that span_idx and span_tag are None # [bsz] batch_range_vector = get_range_vector(batch_size, get_device_of(encoded_text)) # [bsz] gold_positions = span_idx[:, 0] # compute parent arc loss minus_inf = -1e8 mrc_mask = torch.logical_and(mrc_mask, word_mask) parent_scores = parent_scores + (~mrc_mask).float() * minus_inf child_scores = child_scores + (~mrc_mask).float() * minus_inf # [bsz, seq_len] parent_logits = F.log_softmax(parent_scores, dim=-1) parent_arc_nll = -parent_logits[batch_range_vector, gold_positions].mean() # compute parent tag loss parent_tag_nll = F.cross_entropy( parent_tag_scores[batch_range_vector, gold_positions], span_tag) parent_probs = F.softmax(parent_scores, dim=-1) parent_tag_probs = 
F.softmax(parent_tag_scores, dim=-1)
        child_probs = torch.sigmoid(child_scores)
        child_tag_probs = F.softmax(child_tag_scores, dim=-1)

        child_arc_loss = F.binary_cross_entropy_with_logits(
            child_scores, child_arcs.float(), reduction="none")
        child_arc_loss = (child_arc_loss * mrc_mask.float()).sum() / mrc_mask.float().sum()
        child_tag_loss = F.cross_entropy(
            child_tag_scores.view(batch_size * seq_len, -1),
            child_tags.view(-1),
            reduction="none")
        child_tag_loss = (child_tag_loss * child_arcs.float().view(-1)
                          ).sum() / (child_arcs.float().sum() + 1e-8)

        return (parent_probs, parent_tag_probs, child_probs, child_tag_probs,
                parent_arc_nll, parent_tag_nll, child_arc_loss, child_tag_loss)

    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """get word-level embedding"""
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings = self.bert(token_ids,
                               token_type_ids=type_ids,
                               attention_mask=wordpiece_mask)[0]
        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings
        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)
        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0
        return orig_embeddings
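# Both parsers above use the same two masking idioms: add a large negative constant to
# disallowed positions before a softmax, and average an element-wise loss only over
# valid positions. A small self-contained sketch of those idioms; names are illustrative.
import torch
import torch.nn.functional as F

MINUS_INF = -1e8


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """scores: [bsz, seq_len]; mask: [bsz, seq_len] bool, True = valid position."""
    return F.softmax(scores + (~mask).float() * MINUS_INF, dim=-1)


def masked_bce_mean(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean binary cross-entropy over valid positions only (epsilon avoids division by zero)."""
    loss = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction="none")
    return (loss * mask.float()).sum() / (mask.float().sum() + 1e-8)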
class SpanProposal(torch.nn.Module): """ This model is used to extract candidate start/end subtree span rooted at each token. Args: config: SpanProposal Config that defines model dim and structure bert_dir: pretrained bert directory """ def __init__(self, config: SpanProposalConfig, bert_dir: str = ""): super().__init__() self.config = config num_pos_labels = len(config.pos_tags) hidden_size = config.additional_layer_dim if config.additional_layer > 0 else config.pos_dim + config.bert_config.hidden_size self.bert = AutoModel.from_pretrained( pretrained_model_name_or_path=bert_dir, config=config.bert_config) if config.pos_dim > 0: self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim) nn.init.xavier_uniform_(self.pos_embedding.weight) if (config.additional_layer and config.additional_layer_type != "lstm" and config.pos_dim + config.bert_config.hidden_size != hidden_size): self.fuse_layer = nn.Linear( config.pos_dim + config.bert_config.hidden_size, hidden_size) nn.init.xavier_uniform_(self.fuse_layer.weight) self.fuse_layer.bias.data.zero_() else: self.fuse_layer = None else: self.pos_embedding = None if config.additional_layer > 0: if config.additional_layer_type == "transformer": new_config = deepcopy(config.bert_config) new_config.hidden_size = hidden_size new_config.num_hidden_layers = config.additional_layer new_config.hidden_dropout_prob = new_config.attention_probs_dropout_prob = config.mrc_dropout # new_config.attention_probs_dropout_prob = config.biaf_dropout # todo add to hparams and tune self.additional_encoder = BertEncoder(new_config) self.additional_encoder.apply(self._init_bert_weights) else: assert hidden_size % 2 == 0, "Bi-LSTM need an even hidden_size" self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder( input_size=config.pos_dim + config.bert_config.hidden_size, hidden_size=hidden_size // 2, num_layers=config.additional_layer, recurrent_dropout_probability=config.mrc_dropout, use_highway=True) else: self.additional_encoder = None self._dropout = InputVariationalDropout(config.mrc_dropout) self.subtree_start_feedforward = FeedForward( hidden_size, 1, config.arc_representation_dim, Activation.by_name("elu")()) self.subtree_end_feedforward = deepcopy(self.subtree_start_feedforward) # todo: equivalent to self-attention? self.subtree_start_attention = BilinearMatrixAttention( config.arc_representation_dim, config.arc_representation_dim, use_input_biases=True) self.subtree_end_attention = deepcopy(self.subtree_start_attention) # init linear children for layer in self.children(): if isinstance(layer, nn.Linear): nn.init.xavier_uniform_(layer.weight) if layer.bias is not None: layer.bias.data.zero_() def _init_bert_weights(self, module): """ Initialize the weights. 
copy from transformers.BertPreTrainedModel""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_( mean=0.0, std=self.config.bert_config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() @overrides def forward( self, # type: ignore token_ids: torch.LongTensor, type_ids: torch.LongTensor, offsets: torch.LongTensor, wordpiece_mask: torch.BoolTensor, pos_tags: torch.LongTensor, word_mask: torch.BoolTensor, subtree_spans: torch.LongTensor = None, ): """ todo implement docstring Args: token_ids: [batch_size, num_word_pieces] type_ids: [batch_size, num_word_pieces] offsets: [batch_size, num_words, 2] wordpiece_mask: [batch_size, num_word_pieces] pos_tags: [batch_size, num_words] word_mask: [batch_size, num_words] subtree_spans: [batch_size, num_words, 2] Returns: span_start_logits: [batch_size, num_words, num_words] span_end_logits: [batch_size, num_words, num_words] span_loss: if subtree_spans is given. """ # [bsz, seq_len, hidden] embedded_text_input = self.get_word_embedding( token_ids=token_ids, offsets=offsets, wordpiece_mask=wordpiece_mask, type_ids=type_ids, ) if self.pos_embedding is not None: embedded_pos_tags = self.pos_embedding(pos_tags) embedded_text_input = torch.cat( [embedded_text_input, embedded_pos_tags], -1) if self.fuse_layer is not None: embedded_text_input = self.fuse_layer(embedded_text_input) # todo compare normal dropout with InputVariationalDropout embedded_text_input = self._dropout(embedded_text_input) if self.additional_encoder is not None: if self.config.additional_layer_type == "transformer": extended_attention_mask = self.bert.get_extended_attention_mask( word_mask, word_mask.size(), word_mask.device) encoded_text = self.additional_encoder( hidden_states=embedded_text_input, attention_mask=extended_attention_mask)[0] else: encoded_text = self.additional_encoder( inputs=embedded_text_input, mask=word_mask) else: encoded_text = embedded_text_input batch_size, seq_len, encoding_dim = encoded_text.size() # [bsz, seq_len, dim] subtree_start_representation = self._dropout( self.subtree_start_feedforward(encoded_text)) subtree_end_representation = self._dropout( self.subtree_end_feedforward(encoded_text)) # [bsz, seq_len, seq_len] span_start_scores = self.subtree_start_attention( subtree_start_representation, subtree_start_representation) span_end_scores = self.subtree_end_attention( subtree_end_representation, subtree_end_representation) # start of word should less equal to it start_mask = word_mask.unsqueeze(-1) & ( ~torch.triu(span_start_scores.bool(), 1)) # end of word should greater equal to it. 
        end_mask = word_mask.unsqueeze(-1) & torch.triu(span_end_scores.bool())
        minus_inf = -1e8
        span_start_scores = span_start_scores + (
            ~start_mask).float() * minus_inf
        span_end_scores = span_end_scores + (~end_mask).float() * minus_inf

        output = (F.log_softmax(span_start_scores, dim=-1),
                  F.log_softmax(span_end_scores, dim=-1))
        if subtree_spans is not None:
            start_loss = F.cross_entropy(
                span_start_scores.view(batch_size * seq_len, -1),
                subtree_spans[:, :, 0].view(-1))
            end_loss = F.cross_entropy(
                span_end_scores.view(batch_size * seq_len, -1),
                subtree_spans[:, :, 1].view(-1))
            span_loss = start_loss + end_loss
            output = output + (span_loss, )
        return output

    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """get word-level embedding"""
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings = self.bert(token_ids,
                               token_type_ids=type_ids,
                               attention_mask=wordpiece_mask)[0]
        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings
        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)
        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0
        return orig_embeddings
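# The start/end masks in SpanProposal.forward encode the constraint that the subtree
# span rooted at token i must start at a position <= i and end at a position >= i.
# Below is a sketch of that constraint written with explicit triangular masks built
# from ones; note the code above derives the triangle from the score tensor itself,
# so this is an illustration of the intent, not a drop-in replacement.
import torch


def subtree_span_masks(word_mask: torch.Tensor):
    """word_mask: [bsz, seq_len] bool -> (start_mask, end_mask), each [bsz, seq_len, seq_len]."""
    bsz, seq_len = word_mask.shape
    ones = torch.ones(seq_len, seq_len, device=word_mask.device)
    start_allowed = torch.tril(ones).bool()  # start_allowed[i, j] is True iff j <= i
    end_allowed = torch.triu(ones).bool()    # end_allowed[i, j] is True iff j >= i
    start_mask = word_mask.unsqueeze(-1) & start_allowed.unsqueeze(0)
    end_mask = word_mask.unsqueeze(-1) & end_allowed.unsqueeze(0)
    return start_mask, end_mask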
class BiaffineDependencyParser(nn.Module): """ This dependency parser follows the model of [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)] (https://arxiv.org/abs/1611.01734) . But we use BERT for embedding """ def __init__(self, bert_dir, config): super().__init__() self.config = config num_dep_labels = len(config.dep_tags) num_pos_labels = len(config.pos_tags) hidden_size = config.additional_layer_dim if config.pos_dim > 0: self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim) nn.init.xavier_uniform_(self.pos_embedding.weight) if config.additional_layer_type != "lstm" and config.pos_dim + config.hidden_size != hidden_size: self.fuse_layer = nn.Linear( config.pos_dim + config.hidden_size, hidden_size) nn.init.xavier_uniform_(self.fuse_layer.weight) self.fuse_layer.bias.data.zero_() else: self.fuse_layer = None else: self.pos_embedding = None if isinstance(config, BertDependencyConfig): self.bert = BertModel.from_pretrained(bert_dir, config=self.config) elif isinstance(config, RobertaDependencyConfig): self.bert = RobertaModel.from_pretrained(bert_dir, config=self.config) if config.additional_layer > 0: if config.additional_layer_type == "transformer": new_config = deepcopy(config) new_config.hidden_size = hidden_size new_config.num_hidden_layers = config.additional_layer new_config.hidden_dropout_prob = config.biaf_dropout new_config.attention_probs_dropout_prob = config.biaf_dropout # todo add to hparams and tune self.additional_encoder = BertEncoder(new_config) self.additional_encoder.apply(self._init_bert_weights) else: assert hidden_size % 2 == 0, "Bi-LSTM need an even hidden_size" self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder( input_size=config.pos_dim + config.hidden_size, hidden_size=hidden_size // 2, num_layers=config.additional_layer, recurrent_dropout_probability=config.biaf_dropout, use_highway=True) else: self.additional_encoder = None self.head_arc_feedforward = FeedForward(hidden_size, 1, config.arc_representation_dim, Activation.by_name("elu")()) self.child_arc_feedforward = deepcopy(self.head_arc_feedforward) self.arc_attention = BilinearMatrixAttention( config.arc_representation_dim, config.arc_representation_dim, use_input_biases=True) self.head_tag_feedforward = FeedForward(hidden_size, 1, config.tag_representation_dim, Activation.by_name("elu")()) self.child_tag_feedforward = deepcopy(self.head_tag_feedforward) self.tag_bilinear = nn.modules.Bilinear(config.tag_representation_dim, config.tag_representation_dim, num_dep_labels) nn.init.xavier_uniform_(self.tag_bilinear.weight) self.tag_bilinear.bias.data.zero_() self._dropout = InputVariationalDropout(config.biaf_dropout) self._input_dropout = nn.Dropout(config.biaf_dropout) self._head_sentinel = torch.nn.Parameter( torch.randn([1, 1, hidden_size])) def _init_bert_weights(self, module): """ Initialize the weights. 
copy from transformers.BertPreTrainedModel""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() @overrides def forward( self, # type: ignore token_ids: torch.LongTensor, type_ids: torch.LongTensor, offsets: torch.LongTensor, wordpiece_mask: torch.BoolTensor, dep_idxs: torch.LongTensor, dep_tags: torch.LongTensor, pos_tags: torch.LongTensor, word_mask: torch.BoolTensor, ): embedded_text_input = self.get_word_embedding( token_ids=token_ids, offsets=offsets, wordpiece_mask=wordpiece_mask, type_ids=type_ids, ) if self.pos_embedding is not None: embedded_pos_tags = self.pos_embedding(pos_tags) embedded_text_input = torch.cat( [embedded_text_input, embedded_pos_tags], -1) if self.fuse_layer is not None: embedded_text_input = self.fuse_layer(embedded_text_input) # todo compare normal dropout with InputVariationalDropout embedded_text_input = self._input_dropout(embedded_text_input) if self.additional_encoder is not None: if self.config.additional_layer_type == "transformer": extended_attention_mask = self.bert.get_extended_attention_mask( word_mask, word_mask.size(), word_mask.device) encoded_text = self.additional_encoder( hidden_states=embedded_text_input, attention_mask=extended_attention_mask)[0] else: encoded_text = self.additional_encoder( inputs=embedded_text_input, mask=word_mask) else: encoded_text = embedded_text_input batch_size, _, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation. 
encoded_text = torch.cat([head_sentinel, encoded_text], 1) word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask], 1) dep_idxs = torch.cat([dep_idxs.new_zeros(batch_size, 1), dep_idxs], 1) dep_tags = torch.cat([dep_tags.new_zeros(batch_size, 1), dep_tags], 1) encoded_text = self._dropout(encoded_text) # shape (batch_size, sequence_length, arc_representation_dim) head_arc_representation = self._dropout( self.head_arc_feedforward(encoded_text)) child_arc_representation = self._dropout( self.child_arc_feedforward(encoded_text)) # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self._dropout( self.head_tag_feedforward(encoded_text)) child_tag_representation = self._dropout( self.child_tag_feedforward(encoded_text)) # shape (batch_size, sequence_length, sequence_length) attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation) minus_inf = -1e8 minus_mask = ~word_mask * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) if self.training: predicted_heads, predicted_head_tags = self._greedy_decode( head_tag_representation, child_tag_representation, attended_arcs, word_mask) else: predicted_heads, predicted_head_tags = self._mst_decode( head_tag_representation, child_tag_representation, attended_arcs, word_mask) arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=dep_idxs, head_tags=dep_tags, mask=word_mask, ) return predicted_heads, predicted_head_tags, arc_nll, tag_nll def get_word_embedding( self, token_ids: torch.LongTensor, offsets: torch.LongTensor, wordpiece_mask: torch.BoolTensor, type_ids: Optional[torch.LongTensor] = None, ) -> torch.Tensor: # type: ignore """get word-level embedding""" # Shape: [batch_size, num_wordpieces, embedding_size]. embeddings = self.bert(token_ids, token_type_ids=type_ids, attention_mask=wordpiece_mask)[0] # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size) # span_mask: (batch_size, num_orig_tokens, max_span_length) span_embeddings, span_mask = allennlp_util.batched_span_select( embeddings, offsets) span_mask = span_mask.unsqueeze(-1) span_embeddings *= span_mask # zero out paddings span_embeddings_sum = span_embeddings.sum(2) span_embeddings_len = span_mask.sum(2) # Shape: (batch_size, num_orig_tokens, embedding_size) orig_embeddings = span_embeddings_sum / torch.clamp_min( span_embeddings_len, 1) # All the places where the span length is zero, write in zeros. orig_embeddings[(span_embeddings_len == 0).expand( orig_embeddings.shape)] = 0 return orig_embeddings def _construct_loss( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.BoolTensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the arc and tag loss for a sequence given gold head indices and tags. # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : `torch.Tensor`, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : `torch.Tensor`, required. 
A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. head_indices : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. head_tags : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length). The dependency labels of the heads for every word. mask : `torch.BoolTensor`, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. # Returns arc_nll : `torch.Tensor`, required. The negative log likelihood from the arc loss. tag_nll : `torch.Tensor`, required. The negative log likelihood from the arc tag loss. """ batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector( batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = (masked_log_softmax(attended_arcs, mask) * mask.unsqueeze(2) * mask.unsqueeze(1)) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices) normalised_head_tag_logits = masked_log_softmax( head_tag_logits, mask.unsqueeze(-1)) * mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = (timestep_index.view(1, sequence_length).expand( batch_size, sequence_length).long()) # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll def _greedy_decode( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.BoolTensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions by decoding the unlabeled arcs independently for each word and then again, predicting the head tags of these greedily chosen arcs independently. Note that this method of decoding is not guaranteed to produce trees (i.e. there maybe be multiple roots, or cycles when children are attached to their parents). # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : `torch.Tensor`, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. # Returns heads : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. 
head_tags : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. attended_arcs = attended_arcs + torch.diag( attended_arcs.new(mask.size(1)).fill_(-np.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = ~mask.unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -np.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, heads) _, head_tags = head_tag_logits.max(dim=2) return heads, head_tags def _mst_decode( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.BoolTensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions using the Edmonds' Algorithm for finding minimum spanning trees on directed graphs. Nodes in the graph are the words in the sentence, and between each pair of nodes, there is an edge in each direction, where the weight of the edge corresponds to the most likely dependency label probability for that arc. The MST is then generated from this directed graph. # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : `torch.Tensor`, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. # Returns heads : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the optimally decoded heads of each word. """ batch_size, sequence_length, tag_representation_dim = head_tag_representation.size( ) lengths = mask.data.sum(dim=1).long().cpu().numpy() expanded_shape = [ batch_size, sequence_length, sequence_length, tag_representation_dim ] head_tag_representation = head_tag_representation.unsqueeze(2) head_tag_representation = head_tag_representation.expand( *expanded_shape).contiguous() child_tag_representation = child_tag_representation.unsqueeze(1) child_tag_representation = child_tag_representation.expand( *expanded_shape).contiguous() # Shape (batch_size, sequence_length, sequence_length, num_head_tags) pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation) # Note that this log_softmax is over the tag dimension, and we don't consider pairs # of tags which are invalid (e.g are a pair which includes a padded element) anyway below. # Shape (batch, num_labels,sequence_length, sequence_length) normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute( 0, 3, 1, 2) # Mask padded tokens, because we only want to consider actual words as heads. 
minus_inf = -1e8 minus_mask = ~mask * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) # Shape (batch_size, sequence_length, sequence_length) normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2) # Shape (batch_size, num_head_tags, sequence_length, sequence_length) # This energy tensor expresses the following relation: # energy[i,j] = "Score that i is the head of j". In this # case, we have heads pointing to their children. batch_energy = torch.exp( normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits) return self._run_mst_decoding(batch_energy, lengths) @staticmethod def _run_mst_decoding( batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: heads = [] head_tags = [] for energy, length in zip(batch_energy.detach().cpu(), lengths): scores, tag_ids = energy.max(dim=0) # Although we need to include the root node so that the MST includes it, # we do not want any word to be the parent of the root node. # Here, we enforce this by setting the scores for all word -> ROOT edges # edges to be 0. scores[0, :] = 0 # Decode the heads. Because we modify the scores to prevent # adding in word -> ROOT edges, we need to find the labels ourselves. instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False) # Find the labels which correspond to the edges in the max spanning tree. instance_head_tags = [] for child, parent in enumerate(instance_heads): instance_head_tags.append(tag_ids[parent, child].item()) # We don't care what the head or tag is for the root token, but by default it's # not necessarily the same in the batched vs unbatched case, which is annoying. # Here we'll just set them to zero. instance_heads[0] = 0 instance_head_tags[0] = 0 heads.append(instance_heads) head_tags.append(instance_head_tags) return ( torch.from_numpy(np.stack(heads)).to(batch_energy.device), torch.from_numpy(np.stack(head_tags)).to(batch_energy.device), ) def _get_head_tags( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor, ) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : `torch.Tensor`, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. head_indices : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. # Returns head_tag_logits : `torch.Tensor` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ batch_size = head_tag_representation.size(0) # shape (batch_size,) range_vector = get_range_vector( batch_size, get_device_of(head_tag_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. 
        # See here:
        # https://docs.scipy.org/doc/np-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word
        # from the sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[
            range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits
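# A compact illustration of the advanced indexing used in _get_head_tags: for every
# word j in batch element b, select the representation of its head head_indices[b, j]
# from the sequence dimension. The helper name and example inputs are illustrative.
import torch


def select_head_representations(head_repr: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor:
    """head_repr: [bsz, seq_len, dim]; head_indices: [bsz, seq_len] -> [bsz, seq_len, dim]."""
    bsz = head_repr.size(0)
    range_vector = torch.arange(bsz, device=head_repr.device).unsqueeze(1)  # [bsz, 1]
    # Broadcasting [bsz, 1] against [bsz, seq_len] picks head_repr[b, head_indices[b, j], :].
    return head_repr[range_vector, head_indices]


# Example (illustrative):
# select_head_representations(torch.randn(2, 5, 8), torch.zeros(2, 5, dtype=torch.long))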