def get_srl_loss_mask(self, srl_scores, num_predicted_args, num_predicted_preds):
     max_num_arg = srl_scores.size()[1]
     max_num_pred = srl_scores.size()[2]
     # num_predicted_args: 1D tensor of predicted argument counts; max_num_arg: int, the maximum number of gold arguments
     args_mask = util.lengths_to_mask(num_predicted_args, max_num_arg)
     pred_mask = util.lengths_to_mask(num_predicted_preds, max_num_pred)
     srl_loss_mask = args_mask.unsqueeze(2) & pred_mask.unsqueeze(1)
     return srl_loss_mask
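Every example on this page relies on lengths_to_mask (sometimes qualified as util.lengths_to_mask) to turn a 1D tensor of sequence lengths into a boolean padding mask. Below is a minimal sketch of the assumed semantics, not the library's actual implementation, together with the pairwise SRL mask built above:

import torch

def lengths_to_mask(lengths: torch.Tensor, max_len: int = None) -> torch.Tensor:
    # mask[i, j] is True iff position j is a real (non-padded) element of sequence i
    max_len = max_len if max_len is not None else int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

# For a sentence with 2 predicted arguments and 1 predicted predicate,
# the SRL loss mask keeps a 2x1 block of the [args, preds] score grid:
args_mask = lengths_to_mask(torch.tensor([2]), 3)                 # [1, 3]
pred_mask = lengths_to_mask(torch.tensor([1]), 2)                 # [1, 2]
srl_loss_mask = args_mask.unsqueeze(2) & pred_mask.unsqueeze(1)   # [1, 3, 2]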
Example #2
 def compute_loss(self, batch: Dict[str, Any],
                  output: Union[torch.Tensor, Dict[str, torch.Tensor],
                                Iterable[torch.Tensor], Any], criterion):
     mask = util.lengths_to_mask(batch['text_input_ids_length'])
     return criterion(
         output[mask],
         batch['text_prefix_mask'][:, 1:-1][mask].to(torch.float))
Example #3
 def feed_batch(self,
                batch: Dict[str, Any],
                task_name,
                output_dict=None,
                run_transform=False,
                cls_is_bos=False,
                sep_is_eos=False,
                results=None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     h, output_dict = self._encode(batch, task_name, output_dict,
                                   cls_is_bos, sep_is_eos)
     task = self.tasks[task_name]
     if run_transform:
         batch = task.transform_batch(batch,
                                      results=results,
                                      cls_is_bos=cls_is_bos,
                                      sep_is_eos=sep_is_eos)
     batch['mask'] = mask = util.lengths_to_mask(batch['token_length'])
     output_dict[task_name] = {
         'output':
         task.feed_batch(h,
                         batch=batch,
                         mask=mask,
                         decoder=self.model.decoders[task_name]),
         'mask':
         mask
     }
     return output_dict, batch
Example #4
    def decode(self, contextualized_embeddings, gold_starts, gold_ends, gold_labels, masks, max_sent_length,
               num_sentences, sent_lengths):
        # Apply MLPs to starts and ends; output shape [num_sentences, max_sent_length, emb]
        candidate_starts_emb = self.start_mlp(contextualized_embeddings)
        candidate_ends_emb = self.end_mlp(contextualized_embeddings)
        candidate_ner_scores = self.biaffine(candidate_starts_emb, candidate_ends_emb).permute([0, 2, 3, 1])

        """generate candidate spans with argument pruning"""
        # Generate masks
        candidate_scores_mask = masks.unsqueeze(1) & masks.unsqueeze(2)
        device = sent_lengths.device
        # lengths_to_mask(arange(n), n)[i, j] is True iff j < i (end before start),
        # so its negation keeps exactly the spans whose end index is >= the start index
        ends_geq_starts = (
            ~util.lengths_to_mask(torch.arange(max_sent_length, device=device), max_sent_length)) \
            .unsqueeze_(0).expand(num_sentences, -1, -1)
        candidate_scores_mask &= ends_geq_starts
        candidate_ner_scores = candidate_ner_scores[candidate_scores_mask]
        predict_dict = {
            "candidate_ner_scores": candidate_ner_scores,
        }
        if gold_starts is not None:
            gold_ner_labels = self.get_dense_span_labels(gold_starts, gold_ends, gold_labels, max_sent_length)
            loss = torch.nn.functional.cross_entropy(candidate_ner_scores,
                                                     gold_ner_labels[candidate_scores_mask],
                                                     reduction=self.loss_reduction)
            predict_dict['loss'] = loss
        return predict_dict
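The pruning step above combines two constraints: both span endpoints must fall inside the sentence, and the end index must not precede the start index. A small standalone sketch of that mask, assuming the lengths_to_mask semantics sketched earlier (names are illustrative):

import torch

def candidate_span_mask(sent_lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    steps = torch.arange(max_len, device=sent_lengths.device)
    token_mask = steps.unsqueeze(0) < sent_lengths.unsqueeze(1)       # [batch, len]
    in_sentence = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)   # [batch, start, end]
    end_geq_start = steps.unsqueeze(1) <= steps.unsqueeze(0)          # [start, end]
    return in_sentence & end_geq_start

# candidate_span_mask(torch.tensor([2]), 3)[0] keeps the spans (0, 0), (0, 1) and (1, 1).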
Example #5
 def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
     if mask is None:
         mask = util.lengths_to_mask(batch['token_length'])
     else:
         mask = mask.clone()
     scores = self.decoder(contextualized_embeddings, mask)
     mask[:, 0] = 0
     return scores, mask
Example #6
 def unpack(batch, mask=None, training=False):
     keys = 'token_length', 'predicate_offset', 'argument_begin_offset', 'argument_end_offset', 'srl_label_id'
     sent_lengths, gold_predicates, gold_arg_starts, gold_arg_ends, gold_arg_labels = [batch.get(k, None) for k in
                                                                                       keys]
     if mask is None:
         mask = util.lengths_to_mask(sent_lengths)
     # elif not training:
     #     sent_lengths = mask.sum(dim=1)
     return gold_arg_ends, gold_arg_labels, gold_arg_starts, gold_predicates, mask, sent_lengths
Example #7
 def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
     keys = 'token_length', 'begin_offset', 'end_offset', 'label_id'
     sent_lengths, gold_starts, gold_ends, gold_labels = [batch.get(k, None) for k in keys]
     if mask is None:
         mask = util.lengths_to_mask(sent_lengths)
     num_sentences, max_sent_length = mask.size()
     return self.decode(contextualized_embeddings, gold_starts, gold_ends, gold_labels, mask,
                        max_sent_length,
                        num_sentences, sent_lengths)
Example #8
 def feed_batch(self, batch):
     words, feats, lens, puncts = batch.get('token_id', None), batch.get('pos_id', None), batch['sent_length'], \
                                  batch.get('punct_mask', None)
     mask = lengths_to_mask(lens)
     logits = self.model(batch, mask)
     if self.model.training:
         mask = mask.clone()
     # ignore the first token of each sentence
     mask[:, 0] = 0
     return logits, mask, puncts
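The clone before the in-place edit matters during training: the same boolean mask may have been saved for backward inside the model, and mutating it in place would trip autograd's version check (the RuntimeError quoted in a later example). A minimal illustration of the pattern with made-up tensors:

import torch

lens = torch.tensor([4, 2])
mask = torch.arange(5).unsqueeze(0) < lens.unsqueeze(1)  # [2, 5] bool token mask
mask = mask.clone()   # copy, so any tensor saved for backward stays untouched
mask[:, 0] = False    # ignore the first token of each sentence downstream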
Example #9
 def feed_batch(self, batch: dict):
     lens = batch['token_length']
     mask2d = lengths_to_mask(lens)
     pred = self.model(batch, mask=mask2d)
     mask3d = self.compute_mask(mask2d)
     if self.config.crf:
         token_index = mask3d[0]
         pred = pred.flatten(end_dim=1)[token_index]
         pred = F.log_softmax(pred, dim=-1)
     return pred, mask3d
Example #10
 def pool_subtoken(last_hidden_state, token_span, concept_mask,
                   concept_lens, max_concept_len):
     last_hidden_state = pick_tensor_for_each_token(last_hidden_state,
                                                    token_span, True)
     # allocate a zero-padded buffer and fill its left-aligned valid slots with
     # the hidden states picked out by concept_mask
     concept_hidden = torch.zeros(
         (concept_mask.size(0), max_concept_len, last_hidden_state.size(-1)),
         device=concept_mask.device)
     concept_hidden[lengths_to_mask(concept_lens, max_concept_len)] = \
         last_hidden_state[concept_mask]
     return concept_hidden
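The final assignment re-packs the rows selected by concept_mask into the left-aligned valid slots of a fresh zero-padded tensor; both masks select the same number of positions per batch element, so the row order is preserved. A tiny self-contained sketch of the same pattern, with illustrative data:

import torch

src = torch.arange(12.).view(2, 3, 2)            # [batch=2, seq=3, dim=2]
keep = torch.tensor([[True, False, True],
                     [False, True, False]])      # rows to extract per batch item
lens = keep.sum(1)
dst = torch.zeros(2, int(lens.max()), 2)
slots = torch.arange(int(lens.max())).unsqueeze(0) < lens.unsqueeze(1)
dst[slots] = src[keep]                           # kept rows land left-aligned in dst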
Example #11
 def forward(self,
             lens: torch.LongTensor,
             input_ids,
             token_span,
             token_type_ids=None):
     mask = lengths_to_mask(lens)
     x = self.encoder(input_ids,
                      token_span=token_span,
                      token_type_ids=token_type_ids)
     if self.secondary_encoder:
         x = self.secondary_encoder(x, mask=mask)
     x = self.classifier(x)
     return x, mask
Example #12
    def get_span_emb(self,
                     flatted_context_emb,
                     flatted_candidate_starts,
                     flatted_candidate_ends,
                     config,
                     dropout=0.0):
        batch_word_num = flatted_context_emb.size()[0]
        # gather slices from embeddings according to indices
        span_start_emb = flatted_context_emb[flatted_candidate_starts]
        span_end_emb = flatted_context_emb[flatted_candidate_ends]
        span_emb_feature_list = [
            span_start_emb, span_end_emb
        ]  # store the span vector representations for span rep.

        span_width = 1 + flatted_candidate_ends - flatted_candidate_starts  # [num_spans], generate the span width
        max_arg_width = config.max_arg_width

        # get the span width feature emb
        span_width_index = span_width - 1
        span_width_emb = self.span_width_embedding(span_width_index)
        span_width_emb = F.dropout(span_width_emb, dropout, self.training)
        span_emb_feature_list.append(span_width_emb)
        """head features"""
        # for each span, enumerate the token positions [start, start + 1, ..., start + max_arg_width - 1]
        span_indices = torch.arange(0, max_arg_width, device=flatted_context_emb.device).view(1, -1) + \
                       flatted_candidate_starts.view(-1, 1)
        # clamp so the enumerated positions never run past the last token of the flattened batch
        span_indices = torch.clamp(span_indices, max=batch_word_num - 1)
        num_spans, spans_width = span_indices.size()
        # flatten the indices so a single index_select gathers every token of every span at once
        flatted_span_indices = span_indices.view(-1)
        span_text_emb = flatted_context_emb.index_select(
            0, flatted_span_indices).view(num_spans, spans_width, -1)
        span_indices_mask = util.lengths_to_mask(span_width,
                                                 max_len=max_arg_width)
        # project context output to num head
        # head_scores = self.context_projective_layer.forward(flatted_context_emb)
        # get span attention
        # span_attention = head_scores.index_select(0, flatted_span_indices).view(num_spans, spans_width)
        # span_attention = torch.add(span_attention, expanded_span_indices_log_mask).unsqueeze(2)  # control the span len
        # span_attention = F.softmax(span_attention, dim=1)
        # zero out positions beyond each span's end, then average over the window;
        # note the mean divides by max_arg_width rather than the true span length
        span_text_emb = span_text_emb * span_indices_mask.unsqueeze(2)
        span_head_emb = torch.mean(span_text_emb, 1)
        span_emb_feature_list.append(span_head_emb)

        span_emb = torch.cat(span_emb_feature_list, 1)
        return span_emb, None, span_text_emb, span_indices, span_indices_mask
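The head-feature block gathers, for every candidate span, a fixed window of max_arg_width token embeddings starting at the span start, then masks out positions past the span end before averaging. A compact sketch of that gather, with illustrative shapes rather than the original configuration:

import torch

context = torch.randn(6, 8)                # flattened tokens of a batch, emb dim 8
starts = torch.tensor([0, 3])              # candidate span starts
max_arg_width = 3
idx = torch.arange(max_arg_width).view(1, -1) + starts.view(-1, 1)
idx = idx.clamp(max=context.size(0) - 1)   # keep indices inside the flattened batch
span_tokens = context.index_select(0, idx.view(-1)).view(2, max_arg_width, -1)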
Example #13
 def feed_batch(self, batch):
     words, feats, lens, puncts = batch.get('token_id', None), batch.get('pos_id', None), batch['sent_length'], \
                                  batch.get('punct_mask', None)
     mask = lengths_to_mask(lens)
     arc_scores, rel_scores = self.model(words=words,
                                         feats=feats,
                                         mask=mask,
                                         batch=batch,
                                         **batch)
     # ignore the first token of each sentence
     # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
     if self.model.training:
         mask = mask.clone()
     mask[:, 0] = 0
     return arc_scores, rel_scores, mask, puncts
Example #14
    def forward(self,
                batch: Dict[str, torch.Tensor]
                ):
        keys = 'token_length', 'begin_offset', 'end_offset', 'label_id'
        sent_lengths, gold_starts, gold_ends, gold_labels = [batch.get(k, None) for k in keys]
        masks = util.lengths_to_mask(sent_lengths)
        num_sentences, max_sent_length = masks.size()
        raw_embeddings = self.embed(batch, mask=masks)

        raw_embeddings = F.dropout(raw_embeddings, self.lexical_dropout, self.training)

        contextualized_embeddings = self.context_layer(raw_embeddings, masks)
        return self.decoder.decode(contextualized_embeddings, gold_starts, gold_ends, gold_labels, masks,
                                   max_sent_length,
                                   num_sentences, sent_lengths)
Example #15
 def forward(self, batch: dict) -> Dict[str, torch.Tensor]:
     batch['mask'] = mask = lengths_to_mask(batch['text_length'])
     return super().forward(batch,
                            batch['spans'],
                            batch.get('span_labels'),
                            mask=mask)
Example #16
 def compute_mask(self, batch):
     lens = batch['token_length']
     mask = lengths_to_mask(lens)
     return mask
Example #17
    def forward(self, batch):
        def square_average(h, span):
            y = pick_tensor_for_each_token(h, span, True)
            y = y.transpose_(1, 2)
            y = pick_tensor_for_each_token(y, span, True)
            y = y.transpose_(1, 2)
            return y

        def square_mask(mask: torch.Tensor):
            square = mask.unsqueeze(-1).expand(-1, -1, mask.size(-1))
            return square & square.transpose(1, 2)

        def pick_last_concept(src_, last_concept_3d_):
            return src_.gather(
                0, last_concept_3d_.expand([-1, -1, src_.size(-1)]))

        if self.squeeze:
            training = self.training
            if isinstance(self.bert_encoder, ContextualWordEmbeddingModule):
                curr_hidden, next_hidden, attention = self.run_single_transformer(
                    self.bert_encoder.transformer, batch)
                alignment = attention[:, 0, :, :]
                arc_attention = attention[:, 1:, :, :].max(dim=1)[0]
                token_span = batch['token_and_concept_token_span']
                alignment = square_average(alignment, token_span)
                arc_attention = square_average(arc_attention, token_span)
                batch_size, src_len, _ = next_hidden.size()
                input_len = token_span.size(1)
                snt_len = batch['snt_len'].max()
                # layout: [CLS] snt_len [SEP] [DUMMY] concept_len (src_len)
                concept_mask = batch['concept_mask']
                token_mask = batch['token_mask']

                masked_concept_alignment = torch.zeros(
                    (batch_size, src_len, input_len),
                    device=next_hidden.device)
                concept_len = concept_mask.sum(1)
                concept_only_mask = lengths_to_mask(concept_len)
                masked_concept_alignment[concept_only_mask] = alignment[
                    concept_mask]
                masked_concept_token_alignment = torch.zeros(
                    (batch_size, snt_len, src_len), device=next_hidden.device)
                masked_concept_alignment.transpose_(1, 2)
                masked_concept_token_alignment[lengths_to_mask(
                    token_mask.sum(1))] = masked_concept_alignment[token_mask]
                masked_concept_token_alignment = masked_concept_token_alignment.permute(
                    [2, 0, 1])

                square_concept_mask = square_mask(concept_mask)
                masked_arc_attention = torch.zeros(
                    (batch_size, src_len, src_len), device=next_hidden.device)
                square_concept_only_mask = square_mask(concept_only_mask)
                masked_arc_attention[square_concept_only_mask] = arc_attention[
                    square_concept_mask]
                masked_arc_attention.transpose_(0, 1)

                next_hidden.transpose_(0, 1)
                curr_hidden.transpose_(0, 1)
                if not training:
                    last_concept_offset = batch['last_concept_offset']
                    last_concept_3d = last_concept_offset.unsqueeze(
                        0).unsqueeze(-1)

                    masked_concept_token_alignment = pick_last_concept(
                        masked_concept_token_alignment, last_concept_3d)
                    masked_arc_attention = pick_last_concept(
                        masked_arc_attention, last_concept_3d)
                    next_hidden = pick_last_concept(next_hidden,
                                                    last_concept_3d)
                    src_len = 1
                outputs = self.arc_concept_decoder(
                    masked_concept_token_alignment,
                    masked_arc_attention,
                    None,
                    next_hidden,
                    batch['copy_seq'],
                    batch.get('concept_out', None),
                    batch['rel'][1:] if training else None,
                    batch_size,
                    src_len,
                    work=not training)
                rel_loss = self.relation_generator(
                    next_hidden,
                    curr_hidden,
                    target_rel=batch['rel'][1:] if training else None,
                    work=not training)
                return self.ret_squeeze_or_bart(batch, outputs, rel_loss,
                                                training)
            else:
                raise NotImplementedError()
                bert_embed, _ = self.bert_encoder(batch['bert_token'],
                                                  batch['token_subword_index'])
        elif self.bart:
            training = self.training
            bart: BartModel = self.bert_encoder.transformer

            # Run encoder if not cached
            if training:
                encoder_attention_mask, encoder_hidden_states = self.run_bart_encoder(
                    bart, batch)
            else:
                encoder_attention_mask = batch['encoder_attention_mask']
                encoder_hidden_states = batch['encoder_hidden_states']

            # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
            decoder = bart.decoder
            concept_input_ids = batch['concept_input_ids']
            decoder_padding_mask = concept_input_ids == self.tokenizer.tokenizer.pad_token_id
            decoder_mask = batch['decoder_mask']
            decoder_causal_mask = torch.zeros_like(decoder_mask,
                                                   device=decoder_mask.device,
                                                   dtype=torch.float)
            decoder_causal_mask.masked_fill_(~decoder_mask, float('-inf'))
            decoder_causal_mask.unsqueeze_(1)
            decoder_outputs = decoder(
                concept_input_ids,
                encoder_hidden_states,
                encoder_attention_mask,
                decoder_padding_mask,
                decoder_causal_mask=decoder_causal_mask,
                past_key_values=None,
                use_cache=False,
                output_attentions=True,
                output_hidden_states=True,
                return_dict=True,
            )
            next_hidden = decoder_outputs.last_hidden_state
            token_span = batch['token_token_span']
            concept_span = batch['concept_token_span']
            next_hidden = pick_tensor_for_each_token(next_hidden, concept_span,
                                                     True)
            arc_attention, alignment = decoder_outputs.attentions[-1]
            arc_attention = arc_attention.max(dim=1)[0]
            alignment = alignment.max(dim=1)[0]
            alignment = pick_tensor_for_each_token(alignment, concept_span,
                                                   True)
            alignment.transpose_(1, 2)
            alignment = pick_tensor_for_each_token(alignment, token_span, True)
            # Strip [CLS] as it won't be copied anyway
            alignment = alignment[:, 1:, :]
            arc_attention = square_average(arc_attention, concept_span)
            batch_size, src_len, _ = next_hidden.size()
            next_hidden.transpose_(0, 1)
            alignment = alignment.permute(2, 0, 1)
            arc_attention.transpose_(0, 1)

            curr_hidden = decoder_outputs.hidden_states[-2]
            curr_hidden = pick_tensor_for_each_token(curr_hidden, concept_span,
                                                     True)
            curr_hidden.transpose_(0, 1)

            if not training:
                last_concept_offset = batch['last_concept_offset']
                last_concept_3d = last_concept_offset.unsqueeze(0).unsqueeze(
                    -1)

                alignment = pick_last_concept(alignment, last_concept_3d)
                arc_attention = pick_last_concept(arc_attention,
                                                  last_concept_3d)
                next_hidden = pick_last_concept(next_hidden, last_concept_3d)
                src_len = 1
            outputs = self.arc_concept_decoder(
                alignment,
                arc_attention,
                None,
                next_hidden,
                batch['copy_seq'],
                batch.get('concept_out', None),
                batch['rel'][1:] if training else None,
                batch_size,
                src_len,
                work=not training)
            rel_loss = self.relation_generator(
                next_hidden,
                curr_hidden,
                target_rel=batch['rel'][1:] if training else None,
                work=not training)
            return self.ret_squeeze_or_bart(batch, outputs, rel_loss, training)
        if self.bert_encoder is not None:
            word_repr, word_mask, probe = self.encode_step_with_bert(
                batch['tok'],
                batch['lem'],
                batch['pos'],
                batch['ner'],
                batch['word_char'],
                batch=batch)
        else:
            word_repr, word_mask, probe = self.encode_step(
                batch['tok'], batch['lem'], batch['pos'], batch['ner'],
                batch['word_char'])
        concept_repr = self.embed_scale * self.concept_encoder(
            batch['concept_char_in'],
            batch['concept_in']) + self.embed_positions(batch['concept_in'])
        concept_repr = self.concept_embed_layer_norm(concept_repr)
        concept_repr = F.dropout(concept_repr,
                                 p=self.dropout,
                                 training=self.training)
        concept_mask = torch.eq(batch['concept_in'], self.concept_pad_idx)
        attn_mask = self.self_attn_mask(batch['concept_in'].size(0))
        # concept_repr = self.graph_encoder(concept_repr,
        #                          self_padding_mask=concept_mask, self_attn_mask=attn_mask,
        #                          external_memories=word_repr, external_padding_mask=word_mask)
        for idx, layer in enumerate(self.graph_encoder.layers):
            concept_repr, arc_weight, _ = layer(
                concept_repr,
                self_padding_mask=concept_mask,
                self_attn_mask=attn_mask,
                external_memories=word_repr,
                external_padding_mask=word_mask,
                need_weights='max')

        graph_target_rel = batch['rel'][:-1]
        graph_target_arc = torch.ne(graph_target_rel,
                                    self.rel_nil_idx)  # 0 or 1
        graph_arc_mask = torch.eq(graph_target_rel, self.rel_pad_idx)
        graph_arc_loss = F.binary_cross_entropy(arc_weight,
                                                graph_target_arc.float(),
                                                reduction='none')
        graph_arc_loss = graph_arc_loss.masked_fill_(graph_arc_mask, 0.).sum(
            (0, 2))

        if self.decoder.joint_arc_concept:
            probe: torch.Tensor = probe.expand(
                word_repr.size(0) + concept_repr.size(0), -1, -1)
        else:
            probe = probe.expand_as(concept_repr)  # tgt_len x bsz x embed_dim
        concept_loss, arc_loss, rel_loss = self.decoder(probe, word_repr, concept_repr, word_mask, concept_mask,
                                                        attn_mask, batch['copy_seq'],
                                                        target=batch['concept_out'],
                                                        target_rel=batch['rel'][1:])

        concept_tot = concept_mask.size(0) - concept_mask.float().sum(0)
        if self.decoder.joint_arc_concept:
            concept_loss, concept_correct, concept_total = concept_loss
        if rel_loss is not None:
            rel_loss, rel_correct, rel_total = rel_loss
            rel_loss = rel_loss / concept_tot
        concept_loss = concept_loss / concept_tot
        arc_loss = arc_loss / concept_tot
        graph_arc_loss = graph_arc_loss / concept_tot
        if self.decoder.joint_arc_concept:
            # noinspection PyUnboundLocalVariable
            if rel_loss is not None:
                rel_out = (rel_loss.mean(), rel_correct, rel_total)
            else:
                rel_out = None
            return (concept_loss.mean(), concept_correct, concept_total), \
                   arc_loss.mean(), rel_out, graph_arc_loss.mean()
        return concept_loss.mean(), arc_loss.mean(), rel_loss.mean(), graph_arc_loss.mean()