Example #1
    def len_of_sent(self, sent):
        # Transformer tokenizing
        config = self.transformer_config
        tokenizer = self.tokenizer
        xlnet = config_is(config, 'xlnet')
        roberta = config_is(config, 'roberta')
        pad_token = tokenizer.pad_token
        pad_token_id = tokenizer.convert_tokens_to_ids([pad_token])[0]
        cls_token = tokenizer.cls_token
        sep_token = tokenizer.sep_token
        max_seq_length = self.config.max_seq_length
        sent = sent[1:]  # remove <root> use [CLS] instead
        pad_label_idx = self.form_vocab.pad_idx
        sent = [x[0] for x in sent]
        sent = adjust_tokens_for_transformers(sent)
        input_ids, input_mask, segment_ids, prefix_mask = \
            convert_examples_to_features(sent,
                                         max_seq_length,
                                         tokenizer,
                                         cls_token_at_end=xlnet,
                                         # xlnet has a cls token at the end
                                         cls_token=cls_token,
                                         cls_token_segment_id=2 if xlnet else 0,
                                         sep_token=sep_token,
                                         sep_token_extra=roberta,
                                         # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                                         pad_on_left=xlnet,
                                         # pad on the left for xlnet
                                         pad_token_id=pad_token_id,
                                         pad_token_segment_id=4 if xlnet else 0,
                                         pad_token_label_id=pad_label_idx,
                                         do_padding=False)
        return len(input_ids)
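
The method above returns how many subword units a sentence occupies once the transformer tokenizer and its special tokens are applied. As a rough standalone equivalent, the sketch below counts subword pieces with the plain Hugging Face transformers API rather than the project's convert_examples_to_features helper; the model name, the subword_len function and the sample words are illustrative assumptions, and truncation to max_seq_length is ignored.

from transformers import AutoTokenizer

def subword_len(words, tokenizer):
    # Count the subword pieces of a pre-tokenized sentence, plus 2 for the
    # [CLS] and [SEP] special tokens added by BERT-style models.
    pieces = [tokenizer.tokenize(w) for w in words]
    return sum(len(p) for p in pieces) + 2

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')  # illustrative model choice
print(subword_len(['The', 'tokenization', 'length'], tokenizer))

For XLNet- or RoBERTa-style models the special-token layout differs, which is what the cls_token_at_end, sep_token_extra and pad_on_left flags in the snippet account for.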
Example #2
    def inputs_to_samples(self, inputs, gold=False):
        max_seq_length = self.config.get('max_seq_length', 128)
        tokenizer = self._tokenizer
        xlnet = False
        roberta = False
        pad_token = self.pad
        cls_token = '[CLS]'
        sep_token = '[SEP]'
        unk_token = self.unk

        pad_label_idx = self.tag_vocab.pad_idx
        pad_token_id = tokenizer.convert_tokens_to_ids([pad_token])[0]
        for sample in inputs:
            if gold:
                words, tags = sample
            else:
                words, tags = sample, [self.tag_vocab.idx_to_token[1]] * len(sample)

            input_ids, input_mask, segment_ids, label_ids = convert_examples_to_features(
                words,
                max_seq_length,
                tokenizer,
                tags,
                self.tag_vocab.token_to_idx,
                cls_token_at_end=xlnet,
                # xlnet has a cls token at the end
                cls_token=cls_token,
                cls_token_segment_id=2 if xlnet else 0,
                sep_token=sep_token,
                sep_token_extra=roberta,
                # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                pad_on_left=xlnet,
                # pad on the left for xlnet
                pad_token_id=pad_token_id,
                pad_token_segment_id=4 if xlnet else 0,
                pad_token_label_id=pad_label_idx,
                unk_token=unk_token)

            if None in input_ids:
                print(input_ids)
            if None in input_mask:
                print(input_mask)
            if None in segment_ids:
                print(segment_ids)
            yield (input_ids, input_mask, segment_ids), label_ids
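
This generator yields one ((input_ids, input_mask, segment_ids), label_ids) tuple per sentence, falling back to a placeholder tag when no gold labels are supplied. The toy version below keeps only that contract; the toy ids, dummy_tag_id and example sentences are made-up stand-ins, and the real feature extraction done by convert_examples_to_features is not reproduced.

def toy_inputs_to_samples(inputs, gold=False, dummy_tag_id=0):
    # Mirrors the control flow above: gold samples arrive as (words, tags) pairs,
    # prediction samples are bare word lists that receive placeholder tags.
    for sample in inputs:
        if gold:
            words, tags = sample
        else:
            words, tags = sample, [dummy_tag_id] * len(sample)
        input_ids = list(range(1, len(words) + 1))  # toy ids; the real code asks the transformer tokenizer
        input_mask = [1] * len(input_ids)
        segment_ids = [0] * len(input_ids)
        yield (input_ids, input_mask, segment_ids), tags

for feats, labels in toy_inputs_to_samples([['Hello', 'world'], ['Hi']]):
    print(feats, labels)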
Example #3
    def batched_inputs_to_batches(self, corpus, indices, shuffle):
        use_pos = self.use_pos
        if use_pos:
            raw_batch = [[], [], [], []]
        else:
            raw_batch = [[], [], []]
        if self.graph:
            max_len = len(max([corpus[i] for i in indices], key=len))
            for idx in indices:
                arc = np.zeros((max_len, max_len), dtype=bool)
                rel = np.zeros((max_len, max_len), dtype=np.int64)
                for b in raw_batch[:2 if use_pos else 1]:
                    b.append([])
                for m, cells in enumerate(corpus[idx]):
                    if use_pos:
                        for b, c, v in zip(raw_batch, cells, [None, self.cpos_vocab]):
                            b[-1].append(v.get_idx_without_add(c) if v else c)
                    else:
                        for b, c, v in zip(raw_batch, cells, [None]):
                            b[-1].append(c)
                    for n, r in zip(cells[-2], cells[-1]):
                        arc[m, n] = True
                        rid = self.rel_vocab.get_idx_without_add(r)
                        if rid is None:
                            logger.warning(f'Relation OOV: {r} does not exist in the training set')
                            continue
                        rel[m, n] = rid
                raw_batch[-2].append(arc)
                raw_batch[-1].append(rel)
        else:
            for idx in indices:
                for s in raw_batch:
                    s.append([])
                for cells in corpus[idx]:
                    if use_pos:
                        for s, c, v in zip(raw_batch, cells, [None, self.cpos_vocab, None, self.rel_vocab]):
                            s[-1].append(v.get_idx_without_add(c) if v else c)
                    else:
                        for s, c, v in zip(raw_batch, cells, [None, None, self.rel_vocab]):
                            s[-1].append(v.get_idx_without_add(c) if v else c)

        # Transformer tokenizing
        config = self.transformer_config
        tokenizer = self.tokenizer
        xlnet = config_is(config, 'xlnet')
        roberta = config_is(config, 'roberta')
        pad_token = tokenizer.pad_token
        pad_token_id = tokenizer.convert_tokens_to_ids([pad_token])[0]
        cls_token = tokenizer.cls_token
        sep_token = tokenizer.sep_token
        max_seq_length = self.config.max_seq_length
        batch_forms = []
        batch_input_ids = []
        batch_input_mask = []
        batch_prefix_offset = []
        mask_p = self.mask_p
        if mask_p:
            batch_masked_offsets = []
            mask_token_id = tokenizer.mask_token_id
        for sent_idx, sent in enumerate(raw_batch[0]):
            batch_forms.append([self.form_vocab.get_idx_without_add(token) for token in sent])
            sent = adjust_tokens_for_transformers(sent)
            sent = sent[1:]  # remove <root> use [CLS] instead
            pad_label_idx = self.form_vocab.pad_idx
            input_ids, input_mask, segment_ids, prefix_mask = \
                convert_examples_to_features(sent,
                                             max_seq_length,
                                             tokenizer,
                                             cls_token_at_end=xlnet,
                                             # xlnet has a cls token at the end
                                             cls_token=cls_token,
                                             cls_token_segment_id=2 if xlnet else 0,
                                             sep_token=sep_token,
                                             sep_token_extra=roberta,
                                             # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                                             pad_on_left=xlnet,
                                             # pad on the left for xlnet
                                             pad_token_id=pad_token_id,
                                             pad_token_segment_id=4 if xlnet else 0,
                                             pad_token_label_id=pad_label_idx,
                                             do_padding=False)
            num_masks = sum(prefix_mask)
            # assert len(sent) == num_masks  # each token has a True subtoken
            if num_masks < len(sent):  # long sent gets truncated, +1 for root
                batch_forms[-1] = batch_forms[-1][:num_masks + 1]  # form
                raw_batch[-1][sent_idx] = raw_batch[-1][sent_idx][:num_masks + 1]  # rel
                raw_batch[-2][sent_idx] = raw_batch[-2][sent_idx][:num_masks + 1]  # head
                raw_batch[-3][sent_idx] = raw_batch[-3][sent_idx][:num_masks + 1]  # pos
            prefix_mask[0] = True  # <root> is now [CLS]
            prefix_offset = [idx for idx, m in enumerate(prefix_mask) if m]
            batch_input_ids.append(input_ids)
            batch_input_mask.append(input_mask)
            batch_prefix_offset.append(prefix_offset)
            if mask_p:
                if shuffle:
                    size = int(np.ceil(mask_p * len(prefix_offset[1:])))  # never mask [CLS]
                    mask_offsets = np.random.choice(np.arange(1, len(prefix_offset)), size, replace=False)
                    for offset in sorted(mask_offsets):
                        assert 0 < offset < len(input_ids)
                        # mask_word = raw_batch[0][sent_idx][offset]
                        # mask_prefix = tokenizer.convert_ids_to_tokens([input_ids[prefix_offset[offset]]])[0]
                        # assert mask_word.startswith(mask_prefix) or mask_prefix.startswith(
                        #     mask_word) or mask_prefix == "'", \
                        #     f'word {mask_word} prefix {mask_prefix} not match'  # could vs couldn
                        # mask_offsets.append(input_ids[offset]) # subword token
                        # mask_offsets.append(offset)  # form token
                        input_ids[prefix_offset[offset]] = mask_token_id  # mask prefix
                        # whole word masking, mask the rest of the word
                        for i in range(prefix_offset[offset] + 1, len(input_ids) - 1):
                            if prefix_mask[i]:
                                break
                            input_ids[i] = mask_token_id

                    batch_masked_offsets.append(sorted(mask_offsets))
                else:
                    batch_masked_offsets.append([0])  # No masking in prediction

        batch_forms = tf.keras.preprocessing.sequence.pad_sequences(batch_forms, padding='post',
                                                                    value=self.form_vocab.safe_pad_token_idx,
                                                                    dtype='int64')
        batch_input_ids = tf.keras.preprocessing.sequence.pad_sequences(batch_input_ids, padding='post',
                                                                        value=pad_token_id,
                                                                        dtype='int64')
        batch_input_mask = tf.keras.preprocessing.sequence.pad_sequences(batch_input_mask, padding='post',
                                                                         value=0,
                                                                         dtype='int64')
        batch_prefix_offset = tf.keras.preprocessing.sequence.pad_sequences(batch_prefix_offset, padding='post',
                                                                            value=0,
                                                                            dtype='int64')
        batch_heads = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[-2], padding='post',
                                                                    value=0,
                                                                    dtype='int64')
        batch_rels = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[-1], padding='post',
                                                                   value=self.rel_vocab.safe_pad_token_idx,
                                                                   dtype='int64')
        if mask_p:
            batch_masked_offsets = tf.keras.preprocessing.sequence.pad_sequences(batch_masked_offsets, padding='post',
                                                                                 value=pad_token_id,
                                                                                 dtype='int64')
        feats = (tf.constant(batch_input_ids, dtype='int64'), tf.constant(batch_input_mask, dtype='int64'),
                 tf.constant(batch_prefix_offset))
        if use_pos:
            batch_pos = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[1], padding='post',
                                                                      value=self.cpos_vocab.safe_pad_token_idx,
                                                                      dtype='int64')
            feats += (batch_pos,)
        yield (batch_forms, feats), \
              (batch_heads, batch_rels, batch_masked_offsets) if mask_p else (batch_heads, batch_rels)
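
When mask_p is set, the batching method above applies whole-word masking over subword pieces: it samples a fraction of word-initial offsets (never [CLS]), replaces the first piece of each chosen word with the mask token, and keeps masking until the next word begins. The standalone sketch below isolates that loop; the ids, the mask_token_id value and the mask_whole_words function are illustrative assumptions.

import numpy as np

def mask_whole_words(input_ids, prefix_mask, mask_p=0.15, mask_token_id=103, rng=None):
    # prefix_mask[i] is True when input_ids[i] starts a new word; position 0 is [CLS].
    rng = rng or np.random.default_rng()
    ids = list(input_ids)
    prefix_offset = [i for i, m in enumerate(prefix_mask) if m]
    size = int(np.ceil(mask_p * len(prefix_offset[1:])))  # never mask [CLS]
    chosen = rng.choice(np.arange(1, len(prefix_offset)), size, replace=False)
    for offset in sorted(chosen):
        start = prefix_offset[offset]
        ids[start] = mask_token_id              # mask the word-initial piece
        for i in range(start + 1, len(ids) - 1):
            if prefix_mask[i]:                  # stop at the next word boundary
                break
            ids[i] = mask_token_id              # mask the remaining pieces of the word
    return ids, sorted(chosen)

masked_ids, offsets = mask_whole_words(
    input_ids=[101, 200, 201, 202, 300, 102],
    prefix_mask=[True, True, False, False, True, False])
print(masked_ids, offsets)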
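
Because every per-sentence list above is ragged, the final step right-pads them to a common length with tf.keras.preprocessing.sequence.pad_sequences before wrapping them in tensors. A tiny standalone illustration of that call (the ids and pad value are made up):

import tensorflow as tf

batch_input_ids = [[101, 7592, 102],
                   [101, 7592, 2088, 999, 102]]
padded = tf.keras.preprocessing.sequence.pad_sequences(
    batch_input_ids, padding='post', value=0, dtype='int64')
print(padded)
# [[ 101 7592  102    0    0]
#  [ 101 7592 2088  999  102]]

In the method above the pad value is the transformer's pad_token_id for subword ids and each vocabulary's safe_pad_token_idx for forms, relations and POS tags.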