def get_srl_loss_mask(self, srl_scores, num_predicted_args, num_predicted_preds):
    max_num_arg = srl_scores.size()[1]
    max_num_pred = srl_scores.size()[2]
    # num_predicted_args is a 1D tensor; max_num_arg is an int, the maximum number of arguments in the gold answers
    args_mask = util.lengths_to_mask(num_predicted_args, max_num_arg)
    pred_mask = util.lengths_to_mask(num_predicted_preds, max_num_pred)
    # Combine the two masks so that an (argument, predicate) cell is valid only when both exist
    srl_loss_mask = args_mask.unsqueeze(2) & pred_mask.unsqueeze(1)
    return srl_loss_mask
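# --- Illustrative sketch (not part of the original code) ---
# The mask combination above is an outer product of two boolean masks. Assuming
# util.lengths_to_mask(lengths, max_len) returns a mask with mask[i, j] = True iff
# j < lengths[i] (an assumption about its semantics), a toy example looks like this:

import torch


def _lengths_to_mask_sketch(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Assumed semantics of util.lengths_to_mask: True on the first lengths[i] positions of row i
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1)


num_predicted_args = torch.tensor([2, 1])    # arguments per sentence
num_predicted_preds = torch.tensor([1, 3])   # predicates per sentence
args_mask = _lengths_to_mask_sketch(num_predicted_args, 3)    # [batch, max_num_arg]
pred_mask = _lengths_to_mask_sketch(num_predicted_preds, 3)   # [batch, max_num_pred]
# Cell (i, a, p) is True only if sentence i has at least a+1 arguments and p+1 predicates
srl_loss_mask = args_mask.unsqueeze(2) & pred_mask.unsqueeze(1)  # [batch, max_num_arg, max_num_pred]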
def compute_loss(self, batch: Dict[str, Any],
                 output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                 criterion):
    mask = util.lengths_to_mask(batch['text_input_ids_length'])
    # Drop the first and last target columns, then compute the loss over valid positions only
    return criterion(output[mask], batch['text_prefix_mask'][:, 1:-1][mask].to(torch.float))
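# --- Illustrative sketch (not part of the original code) ---
# The pattern above indexes both predictions and targets with the same boolean mask,
# flattening them to the valid (non-padding) positions before the criterion is applied.
# A toy version with hypothetical tensors:

import torch

logits = torch.randn(2, 4)                                    # per-token scores, [batch, max_len]
lengths = torch.tensor([4, 2])
valid = torch.arange(4).unsqueeze(0) < lengths.unsqueeze(1)   # True on real tokens only
targets = torch.randint(0, 2, (2, 4))                         # per-token 0/1 labels

criterion = torch.nn.BCEWithLogitsLoss()
# logits[valid] and targets[valid] are 1D with one entry per real token,
# so padded positions never contribute to the loss
loss = criterion(logits[valid], targets[valid].to(torch.float))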
def feed_batch(self,
               batch: Dict[str, Any],
               task_name,
               output_dict=None,
               run_transform=False,
               cls_is_bos=False,
               sep_is_eos=False,
               results=None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    h, output_dict = self._encode(batch, task_name, output_dict, cls_is_bos, sep_is_eos)
    task = self.tasks[task_name]
    if run_transform:
        batch = task.transform_batch(batch, results=results, cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
    batch['mask'] = mask = util.lengths_to_mask(batch['token_length'])
    output_dict[task_name] = {
        'output': task.feed_batch(h, batch=batch, mask=mask, decoder=self.model.decoders[task_name]),
        'mask': mask
    }
    return output_dict, batch
def decode(self, contextualized_embeddings, gold_starts, gold_ends, gold_labels, masks, max_sent_length,
           num_sentences, sent_lengths):
    # Apply MLPs to starts and ends, [num_sentences, max_sent_length, emb]
    candidate_starts_emb = self.start_mlp(contextualized_embeddings)
    candidate_ends_emb = self.end_mlp(contextualized_embeddings)
    candidate_ner_scores = self.biaffine(candidate_starts_emb, candidate_ends_emb).permute([0, 2, 3, 1])
    # Generate candidate spans with argument pruning:
    # a (start, end) cell is kept only if both tokens are real and end >= start
    candidate_scores_mask = masks.unsqueeze(1) & masks.unsqueeze(2)
    device = sent_lengths.device
    sentence_ends_leq_starts = (
        ~util.lengths_to_mask(torch.arange(max_sent_length, device=device), max_sent_length)) \
        .unsqueeze_(0).expand(num_sentences, -1, -1)
    candidate_scores_mask &= sentence_ends_leq_starts
    candidate_ner_scores = candidate_ner_scores[candidate_scores_mask]
    predict_dict = {
        "candidate_ner_scores": candidate_ner_scores,
    }
    if gold_starts is not None:
        gold_ner_labels = self.get_dense_span_labels(gold_starts, gold_ends, gold_labels, max_sent_length)
        loss = torch.nn.functional.cross_entropy(candidate_ner_scores,
                                                 gold_ner_labels[candidate_scores_mask],
                                                 reduction=self.loss_reduction)
        predict_dict['loss'] = loss
    return predict_dict
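# --- Illustrative sketch (not part of the original code) ---
# The arange-based trick above builds the "end index >= start index" constraint:
# under the assumed semantics, lengths_to_mask(arange(L), L)[i, j] is True iff j < i
# (a strictly lower-triangular pattern), so its negation is exactly an upper-triangular
# boolean matrix including the diagonal, i.e. the valid span cells.

import torch

L = 5
idx = torch.arange(L)
lower_strict = idx.unsqueeze(0) < idx.unsqueeze(1)      # [i, j] True iff j < i
valid_spans = ~lower_strict                             # [i, j] True iff j >= i (end >= start)
assert torch.equal(valid_spans, torch.ones(L, L).triu().bool())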
def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
    if mask is None:
        mask = util.lengths_to_mask(batch['token_length'])
    else:
        mask = mask.clone()
    scores = self.decoder(contextualized_embeddings, mask)
    # Exclude the first token of each sentence from the returned mask
    mask[:, 0] = 0
    return scores, mask
def unpack(batch, mask=None, training=False):
    keys = 'token_length', 'predicate_offset', 'argument_begin_offset', 'argument_end_offset', 'srl_label_id'
    sent_lengths, gold_predicates, gold_arg_starts, gold_arg_ends, gold_arg_labels = [batch.get(k, None) for k in keys]
    if mask is None:
        mask = util.lengths_to_mask(sent_lengths)
    # elif not training:
    #     sent_lengths = mask.sum(dim=1)
    return gold_arg_ends, gold_arg_labels, gold_arg_starts, gold_predicates, mask, sent_lengths
def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
    keys = 'token_length', 'begin_offset', 'end_offset', 'label_id'
    sent_lengths, gold_starts, gold_ends, gold_labels = [batch.get(k, None) for k in keys]
    if mask is None:
        mask = util.lengths_to_mask(sent_lengths)
    num_sentences, max_sent_length = mask.size()
    return self.decode(contextualized_embeddings, gold_starts, gold_ends, gold_labels, mask, max_sent_length,
                       num_sentences, sent_lengths)
def feed_batch(self, batch):
    words, feats, lens, puncts = batch.get('token_id', None), batch.get('pos_id', None), batch['sent_length'], \
                                 batch.get('punct_mask', None)
    mask = lengths_to_mask(lens)
    logits = self.model(batch, mask)
    if self.model.training:
        mask = mask.clone()
        # Ignore the first token of each sentence
        mask[:, 0] = 0
    return logits, mask, puncts
def feed_batch(self, batch: dict):
    lens = batch['token_length']
    mask2d = lengths_to_mask(lens)
    pred = self.model(batch, mask=mask2d)
    mask3d = self.compute_mask(mask2d)
    if self.config.crf:
        token_index = mask3d[0]
        pred = pred.flatten(end_dim=1)[token_index]
        pred = F.log_softmax(pred, dim=-1)
    return pred, mask3d
def pool_subtoken(last_hidden_state, token_span, concept_mask, concept_lens, max_concept_len):
    last_hidden_state = pick_tensor_for_each_token(last_hidden_state, token_span, True)
    concept_hidden = torch.zeros((concept_mask.size(0), max_concept_len, last_hidden_state.size(-1)),
                                 device=concept_mask.device)
    # Pack the hidden states of concept positions into a left-aligned, zero-padded tensor
    concept_hidden[lengths_to_mask(concept_lens, max_concept_len)] = last_hidden_state[concept_mask]
    return concept_hidden
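# --- Illustrative sketch (not part of the original code) ---
# The assignment above selects the rows marked by one boolean mask and writes them into a
# freshly zeroed tensor at the positions marked by another mask. This works because both
# masks select the same number of positions per row, in the same left-to-right order.
# A toy version with hypothetical tensors:

import torch

hidden = torch.arange(12, dtype=torch.float).view(1, 6, 2)                # [batch=1, 6 positions, dim=2]
concept_mask = torch.tensor([[False, True, True, False, True, False]])    # 3 concept positions
concept_lens = concept_mask.sum(1)                                        # tensor([3])
max_concept_len = 4

dest_mask = torch.arange(max_concept_len).unsqueeze(0) < concept_lens.unsqueeze(1)  # left-aligned slots
packed = torch.zeros(1, max_concept_len, 2)
packed[dest_mask] = hidden[concept_mask]   # 3 selected rows land in the first 3 slots; the last slot stays zero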
def forward(self, lens: torch.LongTensor, input_ids, token_span, token_type_ids=None):
    mask = lengths_to_mask(lens)
    x = self.encoder(input_ids, token_span=token_span, token_type_ids=token_type_ids)
    if self.secondary_encoder:
        x = self.secondary_encoder(x, mask=mask)
    x = self.classifier(x)
    return x, mask
def get_span_emb(self, flatted_context_emb, flatted_candidate_starts, flatted_candidate_ends, config, dropout=0.0):
    batch_word_num = flatted_context_emb.size()[0]
    # Gather slices from embeddings according to the start and end indices
    span_start_emb = flatted_context_emb[flatted_candidate_starts]
    span_end_emb = flatted_context_emb[flatted_candidate_ends]
    span_emb_feature_list = [span_start_emb, span_end_emb]  # stores the span vector representations
    span_width = 1 + flatted_candidate_ends - flatted_candidate_starts  # [num_spans]
    max_arg_width = config.max_arg_width
    # Span width feature embedding
    span_width_index = span_width - 1
    span_width_emb = self.span_width_embedding(span_width_index)
    span_width_emb = F.dropout(span_width_emb, dropout, self.training)
    span_emb_feature_list.append(span_width_emb)
    # Head features: enumerate the token positions start, start+1, ..., start+max_arg_width-1 for each span,
    # clamping positions that run past the batch to the last valid index
    span_indices = torch.arange(0, max_arg_width, device=flatted_context_emb.device).view(1, -1) + \
                   flatted_candidate_starts.view(-1, 1)
    span_indices = torch.clamp(span_indices, max=batch_word_num - 1)
    num_spans, spans_width = span_indices.size()[0], span_indices.size()[1]
    flatted_span_indices = span_indices.view(-1)
    span_text_emb = flatted_context_emb.index_select(0, flatted_span_indices).view(num_spans, spans_width, -1)
    span_indices_mask = util.lengths_to_mask(span_width, max_len=max_arg_width)
    # Attention-based head scores, currently disabled in favor of mean pooling:
    # head_scores = self.context_projective_layer.forward(flatted_context_emb)
    # span_attention = head_scores.index_select(0, flatted_span_indices).view(num_spans, spans_width)
    # span_attention = torch.add(span_attention, expanded_span_indices_log_mask).unsqueeze(2)
    # span_attention = F.softmax(span_attention, dim=1)
    # Zero out positions beyond each span's width, then mean-pool over the enumerated positions
    span_text_emb = span_text_emb * span_indices_mask.unsqueeze(2).expand(-1, -1, span_text_emb.size()[-1])
    span_head_emb = torch.mean(span_text_emb, 1)
    span_emb_feature_list.append(span_head_emb)
    span_emb = torch.cat(span_emb_feature_list, 1)
    return span_emb, None, span_text_emb, span_indices, span_indices_mask
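# --- Illustrative sketch (not part of the original code) ---
# The head feature above enumerates up to max_arg_width token positions per span by
# broadcasting an arange against the span starts, masks out positions past each span's
# end, and mean-pools the result. A toy version with hypothetical tensors:

import torch

context = torch.arange(10, dtype=torch.float).view(5, 2)   # 5 flattened tokens, dim=2
starts = torch.tensor([0, 3])
ends = torch.tensor([2, 3])
max_arg_width = 4

span_indices = torch.arange(max_arg_width).view(1, -1) + starts.view(-1, 1)  # [num_spans, max_arg_width]
span_indices = span_indices.clamp(max=context.size(0) - 1)                   # don't run off the batch
span_width = ends - starts + 1
width_mask = torch.arange(max_arg_width).unsqueeze(0) < span_width.unsqueeze(1)

span_text = context[span_indices]                 # [num_spans, max_arg_width, dim]
span_text = span_text * width_mask.unsqueeze(-1)  # zero out positions past each span's end
# Note: like the original, this divides by max_arg_width rather than by the true span width
span_head = span_text.mean(dim=1)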
def feed_batch(self, batch):
    words, feats, lens, puncts = batch.get('token_id', None), batch.get('pos_id', None), batch['sent_length'], \
                                 batch.get('punct_mask', None)
    mask = lengths_to_mask(lens)
    arc_scores, rel_scores = self.model(words=words, feats=feats, mask=mask, batch=batch, **batch)
    # Ignore the first token of each sentence. Clone the mask first, otherwise:
    # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
    if self.model.training:
        mask = mask.clone()
        mask[:, 0] = 0
    return arc_scores, rel_scores, mask, puncts
def forward(self, batch: Dict[str, torch.Tensor]):
    keys = 'token_length', 'begin_offset', 'end_offset', 'label_id'
    sent_lengths, gold_starts, gold_ends, gold_labels = [batch.get(k, None) for k in keys]
    masks = util.lengths_to_mask(sent_lengths)
    num_sentences, max_sent_length = masks.size()
    raw_embeddings = self.embed(batch, mask=masks)
    raw_embeddings = F.dropout(raw_embeddings, self.lexical_dropout, self.training)
    contextualized_embeddings = self.context_layer(raw_embeddings, masks)
    return self.decoder.decode(contextualized_embeddings, gold_starts, gold_ends, gold_labels, masks,
                               max_sent_length, num_sentences, sent_lengths)
def forward(self, batch: dict) -> Dict[str, torch.Tensor]:
    batch['mask'] = mask = lengths_to_mask(batch['text_length'])
    return super().forward(batch, batch['spans'], batch.get('span_labels'), mask=mask)
def compute_mask(self, batch):
    lens = batch['token_length']
    mask = lengths_to_mask(lens)
    return mask
def forward(self, batch):
    def square_average(h, span):
        y = pick_tensor_for_each_token(h, span, True)
        y = y.transpose_(1, 2)
        y = pick_tensor_for_each_token(y, span, True)
        y = y.transpose_(1, 2)
        return y

    def square_mask(mask: torch.Tensor):
        square = mask.unsqueeze(-1).expand(-1, -1, mask.size(-1))
        return square & square.transpose(1, 2)

    def pick_last_concept(src_, last_concept_3d_):
        return src_.gather(0, last_concept_3d_.expand([-1, -1, src_.size(-1)]))

    if self.squeeze:
        training = self.training
        if isinstance(self.bert_encoder, ContextualWordEmbeddingModule):
            curr_hidden, next_hidden, attention = self.run_single_transformer(self.bert_encoder.transformer, batch)
            alignment, arc_attention = attention[:, 0, :, :], attention[:, 1:, :, :].max(dim=1)[0]
            token_span = batch['token_and_concept_token_span']
            alignment = square_average(alignment, token_span)
            arc_attention = square_average(arc_attention, token_span)
            batch_size, src_len, _ = next_hidden.size()
            input_len = token_span.size(1)
            snt_len = batch['snt_len'].max()
            # Layout: [CLS] snt_len [SEP] [DUMMY] concept_len (src_len)
            concept_mask = batch['concept_mask']
            token_mask = batch['token_mask']
            masked_concept_alignment = torch.zeros((batch_size, src_len, input_len), device=next_hidden.device)
            concept_len = concept_mask.sum(1)
            concept_only_mask = lengths_to_mask(concept_len)
            masked_concept_alignment[concept_only_mask] = alignment[concept_mask]
            masked_concept_token_alignment = torch.zeros((batch_size, snt_len, src_len), device=next_hidden.device)
            masked_concept_alignment.transpose_(1, 2)
            masked_concept_token_alignment[lengths_to_mask(token_mask.sum(1))] = masked_concept_alignment[token_mask]
            masked_concept_token_alignment = masked_concept_token_alignment.permute([2, 0, 1])
            square_concept_mask = square_mask(concept_mask)
            masked_arc_attention = torch.zeros((batch_size, src_len, src_len), device=next_hidden.device)
            square_concept_only_mask = square_mask(concept_only_mask)
            masked_arc_attention[square_concept_only_mask] = arc_attention[square_concept_mask]
            masked_arc_attention.transpose_(0, 1)
            next_hidden.transpose_(0, 1)
            curr_hidden.transpose_(0, 1)
            if not training:
                last_concept_offset = batch['last_concept_offset']
                last_concept_3d = last_concept_offset.unsqueeze(0).unsqueeze(-1)
                masked_concept_token_alignment = pick_last_concept(masked_concept_token_alignment, last_concept_3d)
                masked_arc_attention = pick_last_concept(masked_arc_attention, last_concept_3d)
                next_hidden = pick_last_concept(next_hidden, last_concept_3d)
                src_len = 1
            outputs = self.arc_concept_decoder(masked_concept_token_alignment, masked_arc_attention, None,
                                               next_hidden, batch['copy_seq'], batch.get('concept_out', None),
                                               batch['rel'][1:] if training else None, batch_size, src_len,
                                               work=not training)
            rel_loss = self.relation_generator(next_hidden, curr_hidden,
                                               target_rel=batch['rel'][1:] if training else None,
                                               work=not training)
            return self.ret_squeeze_or_bart(batch, outputs, rel_loss, training)
        else:
            raise NotImplementedError()
            bert_embed, _ = self.bert_encoder(batch['bert_token'], batch['token_subword_index'])
    elif self.bart:
        training = self.training
        bart: BartModel = self.bert_encoder.transformer
        # Run the encoder if its outputs are not cached
        if training:
            encoder_attention_mask, encoder_hidden_states = self.run_bart_encoder(bart, batch)
        else:
            encoder_attention_mask = batch['encoder_attention_mask']
            encoder_hidden_states = batch['encoder_hidden_states']
        # Decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder = bart.decoder
        concept_input_ids = batch['concept_input_ids']
        decoder_padding_mask = concept_input_ids == self.tokenizer.tokenizer.pad_token_id
        decoder_mask = batch['decoder_mask']
        # Turn the boolean decoder mask into an additive causal mask: blocked positions become -inf
        decoder_causal_mask = torch.zeros_like(decoder_mask, device=decoder_mask.device, dtype=torch.float)
        decoder_causal_mask.masked_fill_(~decoder_mask, float('-inf'))
        decoder_causal_mask.unsqueeze_(1)
        decoder_outputs = decoder(
            concept_input_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            decoder_padding_mask,
            decoder_causal_mask=decoder_causal_mask,
            past_key_values=None,
            use_cache=False,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )
        next_hidden = decoder_outputs.last_hidden_state
        token_span = batch['token_token_span']
        concept_span = batch['concept_token_span']
        next_hidden = pick_tensor_for_each_token(next_hidden, concept_span, True)
        arc_attention, alignment = decoder_outputs.attentions[-1]
        arc_attention = arc_attention.max(dim=1)[0]
        alignment = alignment.max(dim=1)[0]
        alignment = pick_tensor_for_each_token(alignment, concept_span, True)
        alignment.transpose_(1, 2)
        alignment = pick_tensor_for_each_token(alignment, token_span, True)
        # Strip [CLS] as it won't be copied anyway
        alignment = alignment[:, 1:, :]
        arc_attention = square_average(arc_attention, concept_span)
        batch_size, src_len, _ = next_hidden.size()
        next_hidden.transpose_(0, 1)
        alignment = alignment.permute(2, 0, 1)
        arc_attention.transpose_(0, 1)
        curr_hidden = decoder_outputs.hidden_states[-2]
        curr_hidden = pick_tensor_for_each_token(curr_hidden, concept_span, True)
        curr_hidden.transpose_(0, 1)
        if not training:
            last_concept_offset = batch['last_concept_offset']
            last_concept_3d = last_concept_offset.unsqueeze(0).unsqueeze(-1)
            alignment = pick_last_concept(alignment, last_concept_3d)
            arc_attention = pick_last_concept(arc_attention, last_concept_3d)
            next_hidden = pick_last_concept(next_hidden, last_concept_3d)
            src_len = 1
        outputs = self.arc_concept_decoder(alignment, arc_attention, None, next_hidden, batch['copy_seq'],
                                           batch.get('concept_out', None),
                                           batch['rel'][1:] if training else None, batch_size, src_len,
                                           work=not training)
        rel_loss = self.relation_generator(next_hidden, curr_hidden,
                                           target_rel=batch['rel'][1:] if training else None,
                                           work=not training)
        return self.ret_squeeze_or_bart(batch, outputs, rel_loss, training)
    if self.bert_encoder is not None:
        word_repr, word_mask, probe = self.encode_step_with_bert(batch['tok'], batch['lem'], batch['pos'],
                                                                 batch['ner'], batch['word_char'], batch=batch)
    else:
        word_repr, word_mask, probe = self.encode_step(batch['tok'], batch['lem'], batch['pos'], batch['ner'],
                                                       batch['word_char'])
    concept_repr = self.embed_scale * self.concept_encoder(batch['concept_char_in'], batch['concept_in']) + \
                   self.embed_positions(batch['concept_in'])
    concept_repr = self.concept_embed_layer_norm(concept_repr)
    concept_repr = F.dropout(concept_repr, p=self.dropout, training=self.training)
    concept_mask = torch.eq(batch['concept_in'], self.concept_pad_idx)
    attn_mask = self.self_attn_mask(batch['concept_in'].size(0))
    # concept_repr = self.graph_encoder(concept_repr,
    #                                   self_padding_mask=concept_mask, self_attn_mask=attn_mask,
    #                                   external_memories=word_repr, external_padding_mask=word_mask)
    for idx, layer in enumerate(self.graph_encoder.layers):
        concept_repr, arc_weight, _ = layer(concept_repr,
                                            self_padding_mask=concept_mask,
                                            self_attn_mask=attn_mask,
                                            external_memories=word_repr,
                                            external_padding_mask=word_mask,
                                            need_weights='max')
    graph_target_rel = batch['rel'][:-1]
    graph_target_arc = torch.ne(graph_target_rel, self.rel_nil_idx)  # 0 or 1
    graph_arc_mask = torch.eq(graph_target_rel, self.rel_pad_idx)
    graph_arc_loss = F.binary_cross_entropy(arc_weight, graph_target_arc.float(), reduction='none')
    graph_arc_loss = graph_arc_loss.masked_fill_(graph_arc_mask, 0.).sum((0, 2))
    if self.decoder.joint_arc_concept:
        probe: torch.Tensor = probe.expand(word_repr.size(0) + concept_repr.size(0), -1, -1)
    else:
        probe = probe.expand_as(concept_repr)  # tgt_len x bsz x embed_dim
    concept_loss, arc_loss, rel_loss = self.decoder(probe, word_repr, concept_repr, word_mask, concept_mask,
                                                    attn_mask, batch['copy_seq'],
                                                    target=batch['concept_out'], target_rel=batch['rel'][1:])
    concept_tot = concept_mask.size(0) - concept_mask.float().sum(0)
    if self.decoder.joint_arc_concept:
        concept_loss, concept_correct, concept_total = concept_loss
        if rel_loss is not None:
            rel_loss, rel_correct, rel_total = rel_loss
            rel_loss = rel_loss / concept_tot
    concept_loss = concept_loss / concept_tot
    arc_loss = arc_loss / concept_tot
    graph_arc_loss = graph_arc_loss / concept_tot
    if self.decoder.joint_arc_concept:
        # noinspection PyUnboundLocalVariable
        if rel_loss is not None:
            rel_out = (rel_loss.mean(), rel_correct, rel_total)
        else:
            rel_out = None
        return (concept_loss.mean(), concept_correct, concept_total), arc_loss.mean(), rel_out, graph_arc_loss.mean()
    return concept_loss.mean(), arc_loss.mean(), rel_loss.mean(), graph_arc_loss.mean()
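# --- Illustrative sketch (not part of the original code) ---
# The BART branch of the forward above converts a boolean decoder mask into an additive
# causal mask by filling the blocked positions with -inf before it reaches attention.
# A minimal sketch of that pattern on toy tensors (not the model's real shapes):

import torch

allowed = torch.ones(4, 4).tril().bool()                  # True where a position may be attended to
additive_mask = torch.zeros(allowed.shape, dtype=torch.float)
additive_mask.masked_fill_(~allowed, float('-inf'))       # blocked positions become -inf
attn_logits = torch.randn(4, 4)
attn_probs = torch.softmax(attn_logits + additive_mask, dim=-1)
# Row i now assigns zero probability to every position j > i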