Example #1
    def forward(self, sequence_tensor, span_indices):
        """Propagates forwardly

        Arguments:
            sequence_tensor {tensor} -- sequence representations, shape (batch_size, seq_len, embedding_dim)
            span_indices {tensor} -- span indices, shape (batch_size, num_spans, 2)

        Returns:
            tensor -- span convolution embedding
        """

        # Both tensors have shape (batch_size, num_spans, 1).
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape: (batch_size, num_spans, 1)
        # The span ends are exclusive, so these widths equal the number of tokens in each span.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape: (1, 1, max_batch_span_width)
        max_span_range_indices = get_range_vector(
            max_batch_span_width,
            get_device_of(sequence_tensor)).view(1, 1, -1)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_batch_span_width, but masking
        # values which fall at or beyond the actual length of the span.
        #
        # We're using < here (and for the mask below) because the span ends are
        # exclusive, so span_widths already equals the number of tokens in each span
        # and serves as a non-inclusive upper bound on the range indices.
        span_indices_mask = (max_span_range_indices < span_widths).long()

        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This produces each span's token indices in reverse order, from span_ends - 1 down to span_starts.
        raw_span_indices = (span_ends - 1 - max_span_range_indices)
        # Use ReLU to clamp out-of-span (negative) indices to zero; those positions are masked out anyway.
        span_indices = F.relu(raw_span_indices.float()).long()

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        # batched_index_select first flattens and batch-shifts span_indices
        # (via flatten_and_batch_shift_indices), then gathers the indexed embeddings.
        span_embedding = batched_index_select(sequence_tensor, span_indices)

        batch_size, num_spans, _, _ = span_embedding.size()
        span_conv_embedding = self.cnn_encoder(
            inputs=span_embedding.view(batch_size * num_spans,
                                       max_batch_span_width, -1),
            mask=span_indices_mask.view(batch_size * num_spans,
                                        max_batch_span_width))
        span_conv_embedding = span_conv_embedding.view(batch_size, num_spans,
                                                       -1)
        return span_conv_embedding
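Example #1 relies on three helpers: get_range_vector, get_device_of, and batched_index_select (whose comment above mentions flatten_and_batch_shift_indices). The names match AllenNLP's allennlp.nn.util utilities; the following is only a minimal sketch of how they behave, reconstructed from their usage here, not the repository's actual implementation.

    import torch

    def get_device_of(tensor):
        # -1 for CPU tensors, otherwise the CUDA device index.
        return tensor.get_device() if tensor.is_cuda else -1

    def get_range_vector(size, device):
        # A LongTensor [0, 1, ..., size - 1] on the requested device.
        if device > -1:
            return torch.arange(size, dtype=torch.long, device=torch.device('cuda', device))
        return torch.arange(size, dtype=torch.long)

    def flatten_and_batch_shift_indices(indices, sequence_length):
        # Offset each batch element's indices so they address the flattened
        # (batch_size * sequence_length, embedding_dim) view of the sequence tensor.
        offsets = get_range_vector(indices.size(0), get_device_of(indices)) * sequence_length
        for _ in range(indices.dim() - 1):
            offsets = offsets.unsqueeze(-1)
        return (indices + offsets).view(-1)

    def batched_index_select(target, indices):
        # target: (batch_size, sequence_length, embedding_dim)
        # indices: (batch_size, ...) -> returns (batch_size, ..., embedding_dim)
        flat_indices = flatten_and_batch_shift_indices(indices, target.size(1))
        flat_target = target.reshape(-1, target.size(-1))
        return flat_target.index_select(0, flat_indices).view(*indices.size(), target.size(-1))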
Example #2
    def get_entity_mention_feature(self, batch_wordpiece_tokens,
                                   batch_wordpiece_tokens_index,
                                   batch_entity_mentions, batch_seq_lens):
        """This function extracts entity mention feature using CNN.
        
        Arguments:
            batch_wordpiece_tokens {tensor} -- batch wordpiece tokens
            batch_wordpiece_tokens_index {tensor} -- batch wordpiece tokens index
            batch_entity_mentions {list} -- batch entity mentions
            batch_seq_lens {list} -- batch sequence length list
        
        Returns:
            tensor -- entity mention feature
        """

        batch_seq_reprs, _ = self.bert_encoder(batch_wordpiece_tokens)
        batch_seq_reprs = batched_index_select(batch_seq_reprs,
                                               batch_wordpiece_tokens_index)

        entity_spans = []
        context_spans = []
        for entity_mention, seq_len in zip(batch_entity_mentions,
                                           batch_seq_lens):
            entity_spans.append([[entity_mention[0][0], entity_mention[0][1]],
                                 [entity_mention[1][0], entity_mention[1][1]]])
            context_spans.append([[0, entity_mention[0][0]],
                                  [entity_mention[0][1], entity_mention[1][0]],
                                  [entity_mention[1][1], seq_len]])

        entity_spans_tensor = torch.LongTensor(entity_spans)
        if self.device > -1:
            entity_spans_tensor = entity_spans_tensor.cuda(device=self.device,
                                                           non_blocking=True)

        context_spans_tensor = torch.LongTensor(context_spans)
        if self.device > -1:
            context_spans_tensor = context_spans_tensor.cuda(
                device=self.device, non_blocking=True)

        entity_feature = self.entity_span_extractor(batch_seq_reprs,
                                                    entity_spans_tensor)
        context_feature = self.context_span_extractor(batch_seq_reprs,
                                                      context_spans_tensor)

        entity_feature = self.ent2hidden(entity_feature)
        context_feature = self.context2hidden(context_feature)

        entity_mention_feature = torch.cat([
            context_feature[:, 0, :], entity_feature[:, 0, :],
            context_feature[:, 1, :], entity_feature[:, 1, :],
            context_feature[:, 2, :]
        ], dim=-1).view(len(batch_wordpiece_tokens), -1)

        entity_mention_feature = self.mlp(entity_mention_feature)

        return entity_mention_feature
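For a sentence with two entity mentions, the two entity spans and the three context spans cover the sequence in the order left context, mention 1, middle context, mention 2, right context, which is exactly the concatenation order of the torch.cat above. A toy illustration with hypothetical indices (assuming exclusive span ends, as in Example #1):

    # Hypothetical 12-token sentence whose mentions cover tokens [2, 4) and [7, 9).
    seq_len = 12
    entity_mention = [(2, 4), (7, 9)]

    entity_spans = [[entity_mention[0][0], entity_mention[0][1]],    # [2, 4]   mention 1
                    [entity_mention[1][0], entity_mention[1][1]]]    # [7, 9]   mention 2
    context_spans = [[0, entity_mention[0][0]],                      # [0, 2]   left context
                     [entity_mention[0][1], entity_mention[1][0]],   # [4, 7]   middle context
                     [entity_mention[1][1], seq_len]]                # [9, 12]  right context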
Example #3
    def forward(self, batch_inputs):
        """This function propagates forwardly

        Arguments:
            batch_inputs {dict} -- batch input data
        """

        batch_seq_bert_encoder_repr, batch_cls_repr = self.bert_encoder(batch_inputs['wordpiece_tokens'])

        batch_seq_tokens_encoder_repr = batched_index_select(batch_seq_bert_encoder_repr,
                                                             batch_inputs['wordpiece_tokens_index'])

        batch_inputs['seq_encoder_reprs'] = batch_seq_tokens_encoder_repr
        batch_inputs['seq_cls_repr'] = batch_cls_repr
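Here batched_index_select maps the wordpiece-level BERT outputs back to one vector per original token. The index tensor is presumably built with the usual first-subword convention; a hypothetical illustration (the exact convention may differ in the repository):

    # Hypothetical tokenization: each original token is represented by the
    # position of its first wordpiece.
    wordpiece_tokens = ['[CLS]', 'the', 'transform', '##er', 'rocks', '[SEP]']
    wordpiece_tokens_index = [1, 2, 4]   # 'the', 'transform', 'rocks'
    # batched_index_select(batch_seq_bert_encoder_repr, wordpiece_tokens_index)
    # then returns one representation per original token instead of per wordpiece.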
Example #4
    def forward(self, batch_inputs):
        """This function propagetes forwardly

        Arguments:
            batch_inputs {dict} -- batch inputs
        
        Returns:
            dict -- results
        """

        batch_seq_wordpiece_tokens_repr, batch_seq_cls_repr = self.bert_encoder(
            batch_inputs['wordpiece_tokens'])
        batch_seq_tokens_repr = batched_index_select(batch_seq_wordpiece_tokens_repr,
                                                     batch_inputs['wordpiece_tokens_index'])

        results = {}

        entity_feature = self.entity_span_extractor(batch_seq_tokens_repr,
                                                    batch_inputs['span_mention'])
        entity_feature = self.ent2hidden(entity_feature)

        subj_pos = torch.LongTensor([-1, 0, 1, 2, 3]) + 3
        obj_pos = torch.LongTensor([-3, -2, -1, 0, 1]) + 3

        if self.device > -1:
            subj_pos = subj_pos.cuda(device=self.device, non_blocking=True)
            obj_pos = obj_pos.cuda(device=self.device, non_blocking=True)

        subj_pos_emb = self.position_embedding(subj_pos)
        obj_pos_emb = self.position_embedding(obj_pos)
        pos_emb = torch.cat([subj_pos_emb, obj_pos_emb], dim=1).unsqueeze(0).repeat(
            batch_inputs['wordpiece_tokens_index'].size()[0], 1, 1)

        span_mention_attention_repr = self.attention_encoder(inputs=entity_feature,
                                                             query=batch_seq_cls_repr,
                                                             feature=pos_emb)
        results['span_mention_repr'] = self.mlp_head2(self.mlp_head1(span_mention_attention_repr))

        if self.momentum:
            return results

        zero_loss = torch.Tensor([0])
        zero_loss.requires_grad = True
        if self.device > -1:
            zero_loss = zero_loss.cuda(device=self.device, non_blocking=True)

        if sum([len(masked_index) for masked_index in batch_inputs['masked_index']]) == 0:
            results['masked_token_loss'] = zero_loss
        else:
            masked_wordpiece_tokens_repr = []
            all_masked_label = []
            for masked_index, masked_position, masked_label, seq_wordpiece_tokens_repr in zip(
                    batch_inputs['masked_index'], batch_inputs['masked_position'],
                    batch_inputs['masked_label'], batch_seq_wordpiece_tokens_repr):
                masked_index_tensor = torch.LongTensor(masked_index)
                masked_position_tensor = torch.LongTensor(masked_position)

                if self.device > -1:
                    masked_index_tensor = masked_index_tensor.cuda(device=self.device,
                                                                   non_blocking=True)
                    masked_position_tensor = masked_position_tensor.cuda(device=self.device,
                                                                         non_blocking=True)

                masked_wordpiece_tokens_repr.append(
                    torch.cat([
                        seq_wordpiece_tokens_repr[masked_index_tensor],
                        self.global_position_embedding(masked_position_tensor)
                    ],
                              dim=1))
                all_masked_label.extend(masked_label)

            masked_wordpiece_tokens_input = torch.cat(masked_wordpiece_tokens_repr, dim=0)
            masked_wordpiece_tokens_output = self.masked_token_decoder(
                self.masked_token_mlp(
                    masked_wordpiece_tokens_input)) + self.masked_token_decoder_bias

            all_masked_label_tensor = torch.LongTensor(all_masked_label)
            if self.device > -1:
                all_masked_label_tensor = all_masked_label_tensor.cuda(device=self.device,
                                                                       non_blocking=True)
            results['masked_token_loss'] = self.masked_token_loss(masked_wordpiece_tokens_output,
                                                                  all_masked_label_tensor)

        all_spans = []
        all_spans_label = []
        all_seq_tokens_reprs = []
        for spans, spans_label, seq_tokens_repr in zip(batch_inputs['spans'],
                                                       batch_inputs['spans_label'],
                                                       batch_seq_tokens_repr):
            all_spans.extend(spans)
            all_spans_label.extend(spans_label)
            all_seq_tokens_reprs.extend(seq_tokens_repr for _ in range(len(spans)))

        assert len(all_spans) == len(all_seq_tokens_reprs) and len(all_spans) == len(
            all_spans_label)

        if len(all_spans) == 0:
            results['span_loss'] = zero_loss
        else:
            if self.span_batch_size > 0:
                all_span_loss = []
                for idx in range(0, len(all_spans), self.span_batch_size):
                    batch_ents_tensor = torch.LongTensor(
                        all_spans[idx:idx + self.span_batch_size]).unsqueeze(1)
                    if self.device > -1:
                        batch_ents_tensor = batch_ents_tensor.cuda(device=self.device,
                                                                   non_blocking=True)

                    batch_seq_tokens_reprs = torch.stack(all_seq_tokens_reprs[idx:idx +
                                                                              self.span_batch_size])

                    batch_spans_feature = self.ent2hidden(
                        self.entity_span_extractor(batch_seq_tokens_reprs,
                                                   batch_ents_tensor).squeeze(1))

                    batch_spans_label = torch.LongTensor(all_spans_label[idx:idx +
                                                                         self.span_batch_size])
                    if self.device > -1:
                        batch_spans_label = batch_spans_label.cuda(device=self.device,
                                                                   non_blocking=True)

                    span_outputs = self.entity_span_decoder(
                        self.entity_span_mlp(batch_spans_feature), batch_spans_label)

                    all_span_loss.append(span_outputs['loss'])
                results['span_loss'] = sum(all_span_loss) / len(all_span_loss)
            else:
                all_spans_tensor = torch.LongTensor(all_spans).unsqueeze(1)
                if self.device > -1:
                    all_spans_tensor = all_spans_tensor.cuda(device=self.device, non_blocking=True)
                all_seq_tokens_reprs = torch.stack(all_seq_tokens_reprs)
                all_spans_feature = self.entity_span_extractor(all_seq_tokens_reprs,
                                                               all_spans_tensor).squeeze(1)

                all_spans_feature = self.ent2hidden(all_spans_feature)

                all_spans_label = torch.LongTensor(all_spans_label)
                if self.device > -1:
                    all_spans_label = all_spans_label.cuda(device=self.device, non_blocking=True)

                entity_typing_outputs = self.entity_span_decoder(
                    self.entity_span_mlp(all_spans_feature), all_spans_label)

                results['span_loss'] = entity_typing_outputs['loss']

        return results
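When span_batch_size > 0, Example #4 scores spans in fixed-size chunks to bound peak memory and then averages the per-chunk losses. A generic sketch of that pattern (the helper name and signature are hypothetical):

    import torch

    def chunked_mean_loss(reprs, labels, score_fn, chunk_size):
        # Score items chunk by chunk so only `chunk_size` spans are in memory at once.
        # Note that averaging per-chunk losses gives the (possibly smaller) final chunk
        # the same weight as a full chunk, which differs slightly from a global mean.
        losses = []
        for idx in range(0, len(reprs), chunk_size):
            batch_reprs = torch.stack(reprs[idx:idx + chunk_size])
            batch_labels = torch.LongTensor(labels[idx:idx + chunk_size])
            losses.append(score_fn(batch_reprs, batch_labels))
        return sum(losses) / len(losses)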
Example #5
    def masked_entity_typing(self, batch_inputs):
        """This function pretrains masked entity typing task.
        
        Arguments:
            batch_inputs {dict} -- batch inputs
        """

        seq_wordpiece_tokens_reprs, _ = self.bert_encoder(
            batch_inputs['tokens_id'])
        batch_inputs['seq_wordpiece_tokens_reprs'] = seq_wordpiece_tokens_reprs
        batch_inputs['seq_tokens_reprs'] = batched_index_select(
            seq_wordpiece_tokens_reprs, batch_inputs['tokens_index'])

        all_ents = []
        all_ents_labels = []
        all_seq_tokens_reprs = []
        for ent_spans, ent_labels, seq_tokens_reprs in zip(
                batch_inputs['ent_spans'], batch_inputs['ent_labels'],
                batch_inputs['seq_tokens_reprs']):
            all_ents.extend([span[0], span[1] - 1] for span in ent_spans)
            all_ents_labels.extend(ent_label for ent_label in ent_labels)
            all_seq_tokens_reprs.extend(seq_tokens_reprs
                                        for _ in range(len(ent_spans)))

        if self.ent_batch_size > 0:
            all_entity_typing_loss = []
            for idx in range(0, len(all_ents), self.ent_batch_size):
                batch_ents_tensor = torch.LongTensor(
                    all_ents[idx:idx + self.ent_batch_size]).unsqueeze(1)
                if self.device > -1:
                    batch_ents_tensor = batch_ents_tensor.cuda(
                        device=self.device, non_blocking=True)

                batch_seq_tokens_reprs = torch.stack(
                    all_seq_tokens_reprs[idx:idx + self.ent_batch_size])

                batch_ents_feature = self.ent2hidden(
                    self.entity_span_extractor(batch_seq_tokens_reprs,
                                               batch_ents_tensor).squeeze(1))

                batch_ents_labels = torch.LongTensor(
                    all_ents_labels[idx:idx + self.ent_batch_size])
                if self.device > -1:
                    batch_ents_labels = batch_ents_labels.cuda(
                        device=self.device, non_blocking=True)

                entity_typing_outputs = self.entity_pretrained_decoder(
                    batch_ents_feature, batch_ents_labels)

                all_entity_typing_loss.append(entity_typing_outputs['loss'])

            if len(all_entity_typing_loss) != 0:
                entity_typing_loss = sum(all_entity_typing_loss) / len(
                    all_entity_typing_loss)
            else:
                zero_loss = torch.Tensor([0])
                zero_loss.requires_grad = True
                if self.device > -1:
                    zero_loss = zero_loss.cuda(device=self.device,
                                               non_blocking=True)
                entity_typing_loss = zero_loss
        else:
            all_ents_tensor = torch.LongTensor(all_ents).unsqueeze(1)
            if self.device > -1:
                all_ents_tensor = all_ents_tensor.cuda(device=self.device,
                                                       non_blocking=True)
            all_seq_tokens_reprs = torch.stack(all_seq_tokens_reprs)
            all_ents_feature = self.entity_span_extractor(
                all_seq_tokens_reprs, all_ents_tensor).squeeze(1)

            all_ents_feature = self.ent2hidden(all_ents_feature)

            all_ents_labels = torch.LongTensor(all_ents_labels)
            if self.device > -1:
                all_ents_labels = all_ents_labels.cuda(device=self.device,
                                                       non_blocking=True)

            entity_typing_outputs = self.entity_pretrained_decoder(
                all_ents_feature, all_ents_labels)

            entity_typing_loss = entity_typing_outputs['loss']

        outputs = {}
        outputs['loss'] = entity_typing_loss

        return outputs
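Both Example #4 and Example #5 substitute a zero tensor with requires_grad=True when a batch contributes no masked tokens, spans, or entities, so the returned loss is always safe to call backward() on. That pattern in isolation (a minimal sketch; the function name is illustrative):

    import torch

    def zero_loss_placeholder(device=-1):
        # A differentiable zero loss for batches that contribute no training signal.
        zero_loss = torch.tensor([0.0], requires_grad=True)
        if device > -1:
            zero_loss = zero_loss.cuda(device=device, non_blocking=True)
        return zero_loss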