Example #1
File: conll.py Project: zuoqy/HanLP
class CoNLL_SDP_Transform(CoNLLTransform):
    def __init__(self,
                 config: SerializableDict = None,
                 map_x=True,
                 map_y=True,
                 lower=True,
                 n_buckets=32,
                 n_tokens_per_batch=5000,
                 min_freq=2,
                 **kwargs) -> None:
        super().__init__(config, map_x, map_y, lower, n_buckets,
                         n_tokens_per_batch, min_freq, **kwargs)
        self.orphan_relation = ROOT

    def lock_vocabs(self):
        super().lock_vocabs()
        # heuristic to find the orphan relation
        for rel in self.rel_vocab.idx_to_token:
            if 'root' in rel.lower():
                self.orphan_relation = rel
                break

    def file_to_inputs(self, filepath: str, gold=True):
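        # SDP-style CoNLL files repeat a token once per head: consecutive lines
        # sharing an ID are merged into one token whose head/deprel fields are
        # lists. E.g. (hypothetical rows) "1 He _ PRP _ _ 2 nsubj" followed by
        # "1 He _ PRP _ _ 4 arg1" becomes ['He', 'PRP', [2, 4], ['nsubj', 'arg1']].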
        assert gold, 'only gold files are supported for now'
        for i, sent in enumerate(read_conll(filepath)):
            prev_cells = None
            parsed_sent = []
            heads = []
            rels = []
            for j, cell in enumerate(sent):
                ID = cell[0]
                form = cell[1]
                cpos = cell[3]
                head = cell[6]
                deprel = cell[7]
                if prev_cells and ID != prev_cells[0]:  # found end of token
                    parsed_sent.append(
                        [prev_cells[1], prev_cells[2], heads, rels])
                    heads = []
                    rels = []
                heads.append(head)
                rels.append(deprel)
                prev_cells = [ID, form, cpos, head, deprel]
            parsed_sent.append([prev_cells[1], prev_cells[2], heads, rels])
            yield parsed_sent

    def fit(self, trn_path: str, **kwargs) -> int:
        self.form_vocab = Vocab()
        self.form_vocab.add(
            ROOT)  # make ROOT the 2nd element, while the 0th is pad and the 1st is unk
        self.cpos_vocab = Vocab(pad_token=None, unk_token=None)
        self.rel_vocab = Vocab(pad_token=None, unk_token=None)
        num_samples = 0
        counter = Counter()
        for sent in self.file_to_samples(trn_path, gold=True):
            num_samples += 1
            for idx, (form, cpos, head, deprel) in enumerate(sent):
                if idx == 0:
                    root = form  # the pseudo-root's form is excluded from frequency counting
                else:
                    counter[form] += 1
                self.cpos_vocab.add(cpos)
                self.rel_vocab.update(deprel)

        for token in [
                token for token, freq in counter.items()
                if freq >= self.config.min_freq
        ]:
            self.form_vocab.add(token)
        return num_samples

    def inputs_to_samples(self, inputs, gold=False):
        for sent in inputs:
            sample = []
            for cell in sent:
                cell = list(cell)
                if self.config['lower']:
                    cell[0] = cell[0].lower()
                if not gold:
                    # placeholder head/deprel so gold and predicted inputs
                    # share one schema
                    cell += [[0], [self.rel_vocab.safe_pad_token]]
                sample.append(cell)
            # prepend a pseudo-root so positions line up with CoNLL's 1-based
            # head indices; its fields are arbitrary since the root is masked
            form, cpos, head, deprel = sample[0]
            sample.insert(0, [self.bos, self.bos, [0], deprel])
            yield sample

    def samples_to_dataset(self,
                           samples: Generator,
                           map_x=None,
                           map_y=None,
                           batch_size=5000,
                           shuffle=None,
                           repeat=None,
                           drop_remainder=False,
                           prefetch=1,
                           cache=True) -> tf.data.Dataset:
        def generator():
            # custom bucketing, load corpus into memory
            corpus = list(
                x for x in (samples() if callable(samples) else samples))
            lengths = [1 + len(i) for i in corpus]
            if len(corpus) < 32:
                n_buckets = 1
            else:
                n_buckets = min(self.config.n_buckets, len(corpus))
            buckets = dict(zip(*kmeans(lengths, n_buckets)))
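            # buckets maps a centroid sentence length to the indices of the
            # sentences clustered into that bucket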
            sizes, buckets = zip(*[(size, bucket)
                                   for size, bucket in buckets.items()])
            # the number of chunks in each bucket, which is clipped by
            # range [1, len(bucket)]
            chunks = [
                min(len(bucket), max(round(size * len(bucket) / batch_size),
                                     1))
                for size, bucket in zip(sizes, buckets)
            ]
            range_fn = randperm if shuffle else arange
            for i in tolist(range_fn(len(buckets))):
                split_sizes = [(len(buckets[i]) - j - 1) // chunks[i] + 1
                               for j in range(chunks[i])]
                for batch_indices in tf.split(range_fn(len(buckets[i])),
                                              split_sizes):
                    indices = [buckets[i][j] for j in tolist(batch_indices)]
                    raw_batch = [[], [], [], []]
                    max_len = len(max([corpus[i] for i in indices], key=len))
                    for idx in indices:
                        arc = np.zeros((max_len, max_len), dtype=bool)  # np.bool was removed from NumPy
                        rel = np.zeros((max_len, max_len), dtype=np.int64)
                        for b in raw_batch[:2]:
                            b.append([])
                        for m, cells in enumerate(corpus[idx]):
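                            # zip stops at the shortest argument, so only the
                            # form and cpos columns are converted to ids here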
                            for b, c, v in zip(
                                    raw_batch, cells,
                                [self.form_vocab, self.cpos_vocab]):
                                b[-1].append(v.get_idx_without_add(c))
                            for n, r in zip(cells[2], cells[3]):
                                arc[m, n] = True
                                rid = self.rel_vocab.get_idx_without_add(r)
                                if rid is None:
                                    logger.warning(
                                        f'Relation OOV: {r} does not exist in the training set'
                                    )
                                    continue
                                rel[m, n] = rid
                        raw_batch[-2].append(arc)
                        raw_batch[-1].append(rel)
                    batch = []
                    for b, v in zip(raw_batch,
                                    [self.form_vocab, self.cpos_vocab]):
                        b = tf.keras.preprocessing.sequence.pad_sequences(
                            b,
                            padding='post',
                            value=v.safe_pad_token_idx,
                            dtype='int64')
                        batch.append(b)
                    batch += raw_batch[2:]
                    assert len(batch) == 4
                    yield (batch[0], batch[1]), (batch[2], batch[3])

        # batching is already done by the generator above, so disable the
        # parent's mapping, batching and shuffling
        return super().samples_to_dataset(generator, False, False, 0, False,
                                          repeat, drop_remainder, prefetch,
                                          cache)

    def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
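        # x: (form ids, cpos ids), each shaped [batch, len];
        # y: (boolean arc adjacency, relation ids), each shaped [batch, len, len]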
        types = (tf.int64, tf.int64), (tf.bool, tf.int64)
        shapes = ([None, None], [None, None]), ([None, None,
                                                 None], [None, None, None])
        values = (self.form_vocab.safe_pad_token_idx,
                  self.cpos_vocab.safe_pad_token_idx), (
                      False, self.rel_vocab.safe_pad_token_idx)
        return types, shapes, values

    def Y_to_outputs(self,
                     Y: Union[tf.Tensor, Tuple[tf.Tensor]],
                     gold=False,
                     inputs=None,
                     X=None) -> Iterable:
        arc_preds, rel_preds, mask = Y
        sents = []

        for arc_sent, rel_sent, length in zip(
                arc_preds, rel_preds, tf.math.count_nonzero(mask, axis=-1)):
            sent = []
            for arc, rel in zip(tolist(arc_sent[1:, 1:]),
                                tolist(rel_sent[1:, 1:])):
                ar = []
                for idx, (a, r) in enumerate(zip(arc, rel)):
                    if a:
                        ar.append((idx + 1, self.rel_vocab.idx_to_token[r]))
                if not ar:
                    # orphan
                    ar.append((0, self.orphan_relation))
                sent.append(ar)
            sents.append(sent)

        return sents

    def XY_to_inputs_outputs(self,
                             X: Union[tf.Tensor, Tuple[tf.Tensor]],
                             Y: Union[tf.Tensor, Tuple[tf.Tensor]],
                             gold=False,
                             inputs=None) -> Iterable:
        (words, feats, mask), (arc_preds, rel_preds) = X, Y
        xs = inputs
        ys = self.Y_to_outputs((arc_preds, rel_preds, mask))
        sents = []
        for x, y in zip(xs, ys):
            sent = CoNLLSentence()
            for idx, ((form, cpos), pred) in enumerate(zip(x, y)):
                sent.append(
                    CoNLLWord(id=idx + 1,
                              form=form,
                              cpos=cpos,
                              head=[p[0] for p in pred],
                              deprel=[p[1] for p in pred]))
            sents.append(sent)
        return sents
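
To make the decoding in Y_to_outputs above concrete, here is a minimal, runnable sketch of the same idea, using hypothetical stand-ins (rel_idx_to_token and the made-up matrices) for the class attributes: each row of the boolean arc matrix becomes a list of (head, relation) pairs, and a token left without any head falls back to the orphan relation.

import numpy as np

# hypothetical stand-ins for rel_vocab.idx_to_token and orphan_relation
rel_idx_to_token = ['root', 'nsubj', 'obj']
orphan_relation = 'root'

# one sentence: pseudo-root at position 0 plus three tokens;
# arc[m, n] == True means "token n is a head of token m"
arc_sent = np.zeros((4, 4), dtype=bool)
rel_sent = np.zeros((4, 4), dtype=np.int64)
arc_sent[1, 2], rel_sent[1, 2] = True, 1  # token 1 <-nsubj- token 2
arc_sent[3, 2], rel_sent[3, 2] = True, 2  # token 3 <-obj-   token 2
# token 2 has no head inside the sliced matrix, so it becomes an orphan

sent = []
for arc_row, rel_row in zip(arc_sent[1:, 1:], rel_sent[1:, 1:]):
    ar = [(idx + 1, rel_idx_to_token[r])  # heads are 1-based
          for idx, (a, r) in enumerate(zip(arc_row, rel_row)) if a]
    if not ar:
        ar.append((0, orphan_relation))  # orphan: attach to the root
    sent.append(ar)

print(sent)  # [[(2, 'nsubj')], [(0, 'root')], [(2, 'obj')]]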
Example #2
class CoNLL_SDP_Transform(CoNLLTransform):
    def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32,
                 n_tokens_per_batch=5000, min_freq=2, **kwargs) -> None:
        super().__init__(config, map_x, map_y, lower, n_buckets, n_tokens_per_batch, min_freq, **kwargs)
        self.orphan_relation = ROOT

    def lock_vocabs(self):
        super().lock_vocabs()
        # heuristic to find the orphan relation
        for rel in self.rel_vocab.idx_to_token:
            if 'root' in rel.lower():
                self.orphan_relation = rel
                break

    def file_to_inputs(self, filepath: str, gold=True):
        assert gold, 'only gold files are supported for now'
        for i, sent in enumerate(read_conll(filepath)):
            prev_cells = None
            parsed_sent = []
            heads = []
            rels = []
            for j, cell in enumerate(sent):
                ID = cell[0]
                form = cell[1]
                cpos = cell[3]
                head = cell[6]
                deprel = cell[7]
                if prev_cells and ID != prev_cells[0]:  # found end of token
                    parsed_sent.append([prev_cells[1], prev_cells[2], heads, rels])
                    heads = []
                    rels = []
                heads.append(head)
                rels.append(deprel)
                prev_cells = [ID, form, cpos, head, deprel]
            parsed_sent.append([prev_cells[1], prev_cells[2], heads, rels])
            yield parsed_sent

    def fit(self, trn_path: str, **kwargs) -> int:
        self.form_vocab = Vocab()
        self.form_vocab.add(ROOT)  # make ROOT the 2nd element, while the 0th is pad and the 1st is unk
        self.cpos_vocab = Vocab(pad_token=None, unk_token=None)
        self.rel_vocab = Vocab(pad_token=None, unk_token=None)
        num_samples = 0
        counter = Counter()
        for sent in self.file_to_samples(trn_path, gold=True):
            num_samples += 1
            for idx, (form, cpos, head, deprel) in enumerate(sent):
                if idx == 0:
                    root = form  # the pseudo-root's form is excluded from frequency counting
                else:
                    counter[form] += 1
                self.cpos_vocab.add(cpos)
                self.rel_vocab.update(deprel)

        for token in [token for token, freq in counter.items() if freq >= self.config.min_freq]:
            self.form_vocab.add(token)
        return num_samples

    def inputs_to_samples(self, inputs, gold=False):
        for sent in inputs:
            sample = []
            for cell in sent:
                cell = list(cell)
                if self.config['lower']:
                    cell[0] = cell[0].lower()
                if not gold:
                    # placeholder head/deprel so gold and predicted inputs share one schema
                    cell += [[0], [self.rel_vocab.safe_pad_token]]
                sample.append(cell)
            # prepend a pseudo-root so positions line up with CoNLL's 1-based
            # head indices; its fields are arbitrary since the root is masked
            form, cpos, head, deprel = sample[0]
            sample.insert(0, [self.bos, self.bos, [0], deprel])
            yield sample

    def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
        types = (tf.int64, tf.int64), (tf.bool, tf.int64)
        shapes = ([None, None], [None, None]), ([None, None, None], [None, None, None])
        values = (self.form_vocab.safe_pad_token_idx, self.cpos_vocab.safe_pad_token_idx), (
            False, self.rel_vocab.safe_pad_token_idx)
        return types, shapes, values

    def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
        arc_preds, rel_preds, mask = Y
        sents = []

        for arc_sent, rel_sent, length in zip(arc_preds, rel_preds,
                                              tf.math.count_nonzero(mask, axis=-1)):
            sent = []
            for arc, rel in zip(tolist(arc_sent[1:, 1:]), tolist(rel_sent[1:, 1:])):
                ar = []
                for idx, (a, r) in enumerate(zip(arc, rel)):
                    if a:
                        ar.append((idx + 1, self.rel_vocab.idx_to_token[r]))
                if not ar:
                    # orphan
                    ar.append((0, self.orphan_relation))
                sent.append(ar)
            sents.append(sent)

        return sents

    def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]], Y: Union[tf.Tensor, Tuple[tf.Tensor]],
                             gold=False, inputs=None, conll=True) -> Iterable:
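        # unlike Example #1, conll=True yields CoNLLWord rows, while
        # conll=False yields raw [heads, deprels] lists per token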
        (words, feats, mask), (arc_preds, rel_preds) = X, Y
        xs = inputs
        ys = self.Y_to_outputs((arc_preds, rel_preds, mask))
        sents = []
        for x, y in zip(xs, ys):
            sent = CoNLLSentence()
            for idx, ((form, cpos), pred) in enumerate(zip(x, y)):
                head = [p[0] for p in pred]
                deprel = [p[1] for p in pred]
                if conll:
                    sent.append(CoNLLWord(id=idx + 1, form=form, cpos=cpos, head=head, deprel=deprel))
                else:
                    sent.append([head, deprel])
            sents.append(sent)
        return sents

    def batched_inputs_to_batches(self, corpus, indices, shuffle=False):
        """
        Convert batched inputs to batches of samples

        Parameters
        ----------
        corpus : list
            A list of inputs
        indices : list
            A list of indices, each list belongs to a batch

        Returns
        -------
        None

        Yields
        -------
        tuple
            tuple of tf.Tensor
        """
        raw_batch = [[], [], [], []]
        max_len = len(max([corpus[i] for i in indices], key=len))
        for idx in indices:
            arc = np.zeros((max_len, max_len), dtype=bool)  # np.bool was removed from NumPy
            rel = np.zeros((max_len, max_len), dtype=np.int64)
            for b in raw_batch[:2]:
                b.append([])
            for m, cells in enumerate(corpus[idx]):
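                # zip stops at the shortest argument, so only the form and
                # cpos columns are converted to ids here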
                for b, c, v in zip(raw_batch, cells,
                                   [self.form_vocab, self.cpos_vocab]):
                    b[-1].append(v.get_idx_without_add(c))
                for n, r in zip(cells[2], cells[3]):
                    arc[m, n] = True
                    rid = self.rel_vocab.get_idx_without_add(r)
                    if rid is None:
                        logger.warning(f'Relation OOV: {r} does not exist in the training set')
                        continue
                    rel[m, n] = rid
            raw_batch[-2].append(arc)
            raw_batch[-1].append(rel)
        batch = []
        for b, v in zip(raw_batch, [self.form_vocab, self.cpos_vocab]):
            b = tf.keras.preprocessing.sequence.pad_sequences(b, padding='post',
                                                              value=v.safe_pad_token_idx,
                                                              dtype='int64')
            batch.append(b)
        batch += raw_batch[2:]
        assert len(batch) == 4
        yield (batch[0], batch[1]), (batch[2], batch[3])
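
The padding step above uses Keras' pad_sequences to right-pad each batch's variable-length id lists. A minimal sketch of just that step, with made-up ids and a stand-in pad index:

import tensorflow as tf

# hypothetical form ids for a batch of three sentences
form_ids = [[2, 7, 4], [2, 5], [2, 9, 8, 3]]
pad_idx = 0  # stand-in for form_vocab.safe_pad_token_idx

padded = tf.keras.preprocessing.sequence.pad_sequences(
    form_ids, padding='post', value=pad_idx, dtype='int64')
print(padded)
# [[2 7 4 0]
#  [2 5 0 0]
#  [2 9 8 3]]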