Example 1
    def is_unk(self, token: str) -> bool:
        """
        Check whether a token is covered by the vocabulary.

        :param token: token to look up
        :return: True if the token is unknown (not covered by the vocabulary),
            False otherwise
        """
        return self.stoi[token] == DEFAULT_UNK_ID()
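
The check above assumes that `stoi` is a `collections.defaultdict` whose default factory is `DEFAULT_UNK_ID`, so any token missing from the vocabulary falls back to the UNK index. A minimal, self-contained sketch of that convention (the constant values here are assumptions):

from collections import defaultdict

DEFAULT_UNK_ID = lambda: 0        # assumed: index reserved for the UNK token
UNK_TOKEN = "<unk>"

stoi = defaultdict(DEFAULT_UNK_ID)            # unknown lookups return 0
stoi.update({UNK_TOKEN: 0, "hello": 1, "world": 2})

assert stoi["hello"] == 1                     # covered token
assert stoi["missing"] == DEFAULT_UNK_ID()    # OOV token maps to the UNK id
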
def build_vocab(field: str,
                max_size: int,
                min_freq: int,
                dataset: object,
                vocab_file: str = None) -> Vocabulary:
    """
    Builds vocabulary for a torchtext `field` from a given `dataset` or
    `vocab_file`.

    :param field: attribute e.g. "src"
    :param max_size: maximum size of vocabulary
    :param min_freq: minimum frequency for an item to be included
    :param dataset: dataset to load data for field from
    :param vocab_file: file to store the vocabulary,
        if not None, load vocabulary from here
    :return: Vocabulary created from either `dataset` or `vocab_file`
    """

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
    else:
        # create newly
        def filter_min(counter: Counter, min_freq: int):
            """ Filter counter by min frequency """
            filtered_counter = Counter(
                {t: c
                 for t, c in counter.items() if c >= min_freq})
            return filtered_counter

        def sort_and_cut(counter: Counter, limit: int):
            """ Cut counter to most frequent,
            sorted numerically and alphabetically"""
            # sort by frequency, then alphabetically
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        vocab_tokens = []
        for label in dataset[1]:
            for word in label:
                vocab_tokens.append(word)

        # counter = Counter(vocab_tokens)
        # if min_freq > -1:
        #     counter = filter_min(counter, min_freq)
        # vocab_tokens = sort_and_cut(counter, max_size)
        assert len(vocab_tokens) <= max_size or max_size == -1
        vocab = Vocabulary(tokens=vocab_tokens)
        print(vocab)
        assert len(vocab) <= max_size + len(vocab.specials) or max_size == -1
        assert vocab.itos[DEFAULT_UNK_ID()] == UNK_TOKEN

    # check for all except for UNK token whether they are OOVs
    for s in vocab.specials[1:]:
        assert not vocab.is_unk(s)

    return vocab
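
A hedged usage sketch for this variant: it only iterates `dataset[1]`, so the dataset is assumed to be an (inputs, labels) pair where each label is a token sequence; `field` and `min_freq` are not actually used by this variant, and the values below are made up for illustration:

# hypothetical (inputs, labels) pair; only the labels are read by this variant
dataset = (
    [["hallo", "welt"], ["ich", "mag", "katzen"]],      # inputs (unused here)
    [["hello", "world"], ["i", "like", "cats"]],        # labels
)

label_vocab = build_vocab(field="trg", max_size=-1, min_freq=-1, dataset=dataset)
print(len(label_vocab), label_vocab.is_unk("cats"))
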
Example 3
def build_vocab(field, max_size, min_freq, data, vocab_file=None):
    """
    Builds vocabulary for a torchtext `field` from the given `data` or
    `vocab_file`.

    :param field: attribute e.g. "src" or "trg"
    :param max_size: maximum size of vocabulary
    :param min_freq: minimum frequency for an item to be included
    :param data: dataset to load data for field from
    :param vocab_file: file to store the vocabulary,
        if not None, load vocabulary from here
    :return: Vocabulary created from either `data` or `vocab_file`
    """

    # special symbols
    specials = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
        vocab.add_tokens(specials)
    else:
        # create newly
        def filter_min(counter, min_freq):
            """ Filter counter by min frequency """
            filtered_counter = Counter({t: c for t, c in counter.items()
                                        if c >= min_freq})
            return filtered_counter

        def sort_and_cut(counter, limit):
            """ Cut counter to most frequent,
            sorted numerically and alphabetically"""
            # sort by frequency, then alphabetically
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        tokens = []
        for i in data.examples:
            if field == "src":
                tokens.extend(i.src)
            elif field == "trg":
                tokens.extend(i.trg)

        counter = Counter(tokens)
        if min_freq > -1:
            counter = filter_min(counter, min_freq)
        vocab_tokens = specials + sort_and_cut(counter, max_size)
        assert vocab_tokens[DEFAULT_UNK_ID()] == UNK_TOKEN
        assert len(vocab_tokens) <= max_size + len(specials)
        vocab = Vocabulary(tokens=vocab_tokens)

    # check for all except for UNK token whether they are OOVs
    for s in specials[1:]:
        assert not vocab.is_unk(s)

    return vocab
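
The two-pass sort inside `sort_and_cut` relies on Python's sort being stable: sorting by token first and then re-sorting by frequency in descending order keeps the alphabetical order among tokens with equal counts. A standalone illustration of that trick:

from collections import Counter

counter = Counter({"banana": 3, "apple": 3, "cherry": 5, "date": 1})

items = sorted(counter.items(), key=lambda tup: tup[0])   # alphabetical first
items.sort(key=lambda tup: tup[1], reverse=True)          # then by frequency (stable)
print([t for t, _ in items])   # ['cherry', 'apple', 'banana', 'date']
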
Example 4
def build_vocab(fields: Union[str, Tuple[str]],
                max_size: int,
                min_freq: int,
                dataset: Union[Dataset, Tuple[Dataset]],
                vocab_file: str = None) -> Vocabulary:
    """
    Builds vocabulary for a torchtext `field` from a given `dataset`
    (or tuple of datasets) or `vocab_file`.

    :param fields: attribute e.g. "src", or Tuple of attributes (kb task), e.g. ("src", "kbsrc")
    :param max_size: maximum size of vocabulary
    :param min_freq: minimum frequency for an item to be included
    :param dataset: dataset to load data for field from
    :param vocab_file: file to store the vocabulary,
        if not None, load vocabulary from here
    :return: Vocabulary created from either `dataset` or `vocab_file`
    """

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
    else:
        # create newly
        def filter_min(counter: Counter, min_freq: int):
            """ Filter counter by min frequency """
            filtered_counter = Counter(
                {t: c
                 for t, c in counter.items() if c >= min_freq})
            return filtered_counter

        def sort_and_cut(counter: Counter, limit: int):
            """ Cut counter to most frequent,
            sorted numerically and alphabetically"""
            # sort by frequency, then alphabetically
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        if isinstance(dataset, tuple):
            assert len(
                dataset
            ) == 2, "build_vocab currently only supports looking at either just 1 dataset or a tuple of 2 datasets"
            dataset, kb_dataset = dataset
        else:
            kb_dataset = None
        if isinstance(fields, tuple):
            assert len(
                fields
            ) == 2, "build_vocab currently only supports looking at either one field (default joeynmt) or two fields (kb_task)"
        else:
            assert isinstance(fields, str), fields
            fields = (fields, )
            print(f"data for field={fields}")

        print(f"processing data for fields={fields}")

        # FIXME greedily trying to get all we can; TODO match fields with dataset fields
        warning = {f: 0 for f in fields}
        tokens = []
        # kb_dataset is None when only a single dataset was passed in
        all_examples = dataset.examples + (kb_dataset.examples if kb_dataset is not None else [])
        for ex in all_examples:
            for f in fields:
                try:
                    tokens.extend(getattr(ex, f))
                except AttributeError:
                    warning[f] += 1
                    continue  # fail in silence
        print(
            f"processed data for fields={fields} with {[(f, warning[f]) for f in fields]} Attribute Errors per field\n"
        )

        counter = Counter(tokens)
        if min_freq > -1:
            counter = filter_min(counter, min_freq)
        vocab_tokens = sort_and_cut(counter, max_size)
        assert len(vocab_tokens) <= max_size

        vocab = Vocabulary(tokens=vocab_tokens)
        assert len(vocab) <= max_size + len(vocab.specials)
        assert vocab.itos[DEFAULT_UNK_ID()] == UNK_TOKEN

    # check for all except for UNK token whether they are OOVs
    for s in vocab.specials[1:]:
        assert not vocab.is_unk(s)

    return vocab
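
A sketch of calling this tuple-aware variant for the KB task; the `SimpleNamespace` stand-ins below are assumptions, used only because the function just needs objects with an `.examples` list whose entries expose the named fields:

from types import SimpleNamespace

# hypothetical stand-ins for the two torchtext Datasets
train_data = SimpleNamespace(examples=[
    SimpleNamespace(src=["where", "is", "the", "cafe"], trg=["navigate"]),
])
kb_data = SimpleNamespace(examples=[
    SimpleNamespace(kbsrc=["cafe", "distance", "2_miles"]),
])

src_vocab = build_vocab(fields=("src", "kbsrc"), max_size=100, min_freq=-1,
                        dataset=(train_data, kb_data))
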
Example 5
def build_vocab(field: str,
                max_size: int,
                min_freq: int = 0,
                dataset: Dataset = None,
                vocab_file: str = None,
                encoding=DEFAULT_ENCODING,
                lower=False) -> Vocabulary:
    """
    Builds vocabulary for a torchtext `field` from a given `dataset` or
    `vocab_file`.

    :param field: attribute e.g. "src"
    :param max_size: maximum size of vocabulary
    :param min_freq: minimum frequency for an item to be included
    :param dataset: dataset to load data for field from
    :param vocab_file: file to store the vocabulary,
        if not None, load vocabulary from here
    :param encoding: encoding used when reading the vocabulary file
    :param lower: whether tokens are lowercased
    :return: Vocabulary created from either `dataset` or `vocab_file`
    """

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file, encoding=encoding, lower=lower)
    elif dataset is not None:
        # create newly
        def filter_min(counter: Counter, min_freq: int):
            """ Filter counter by min frequency """
            filtered_counter = Counter(
                {t: c
                 for t, c in counter.items() if c >= min_freq})
            return filtered_counter

        def sort_and_cut(counter: Counter, limit: int):
            """ Cut counter to most frequent,
            sorted numerically and alphabetically"""
            # sort by frequency, then alphabetically
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        tokens = []
        for i in dataset.examples:
            if field == "src":
                tokens.extend(i.src)
            elif field == "trg":
                tokens.extend(i.trg)

        counter = Counter(tokens)
        if min_freq > 0:
            counter = filter_min(counter, min_freq)
        vocab_tokens = sort_and_cut(counter, max_size)
        assert len(vocab_tokens) <= max_size

        vocab = Vocabulary(tokens=vocab_tokens)
        assert len(vocab) <= max_size + len(vocab.specials)
        assert vocab.itos[DEFAULT_UNK_ID()] == UNK_TOKEN

    else:
        raise ValueError("At least one of the arguments `dataset` or `vocab_file` must not be None!")

    # check for all except for UNK token whether they are OOVs
    for s in vocab.specials[1:]:
        assert not vocab.is_unk(s)

    return vocab
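
Loading from an existing vocabulary file with this variant would look roughly like the following; the path is hypothetical, and the assumed format is the usual one-token-per-line file that `Vocabulary(file=...)` reads in these examples:

# hypothetical: reuse a stored vocabulary instead of counting the dataset
src_vocab = build_vocab(field="src",
                        max_size=30000,
                        vocab_file="models/my_model/src_vocab.txt",
                        encoding="utf-8",
                        lower=False)
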