Example #1
    def _get_dataset_path(dataset_name):

        default_cache_path = get_cache_path()
        url = _get_dataset_url(dataset_name)
        output_dir = cached_path(url_or_filename=url,
                                 cache_dir=default_cache_path,
                                 name='dataset')

        return output_dir
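
A usage sketch for the helper above (assuming it is importable from its module and that the name passed in is one registered with `_get_dataset_url`; the name used below is a placeholder, not a guaranteed registry entry):

    # Usage sketch -- 'some-registered-dataset' is a placeholder assumption;
    # substitute a name that _get_dataset_url actually knows about.
    local_dir = _get_dataset_path('some-registered-dataset')
    # cached_path downloads and unpacks the archive on the first call and
    # returns the cached directory; later calls reuse the cached copy.
    print(local_dir)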
Example #2
    def _get_model(self, model_type):
        if model_type == 'base':
            url = 'http://212.129.155.247/fasthan/fasthan_base.zip'
        elif model_type == 'large':
            url = 'http://212.129.155.247/fasthan/fasthan_large.zip'
        else:
            raise ValueError("model_type can only be base or large.")

        model_dir = cached_path(url, name='fasthan')
        return model_dir
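
For context, this method is the download hook of fastHan's model wrapper; a minimal sketch of how it is normally reached (the `FastHan` constructor is the package's documented entry point, but treat the exact signature as an assumption):

    # Sketch, assuming the fastHan package is installed: constructing the model
    # calls self._get_model(model_type), which in turn calls cached_path and
    # downloads the zip archive on first use.
    from fastHan import FastHan

    model = FastHan(model_type='base')   # resolves to the cached model directory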
Example #3
    def _get_model(self, model_type):

        # First check whether the model is already cached in the local directory; download it if not.

        if model_type == 'base':
            url = 'http://d1qbdbol06y129.cloudfront.net/fasthan_base.zip'
        else:
            raise ValueError("model_type can only be base.")

        model_dir = cached_path(url, name='fasthan')
        return model_dir
Example #4
def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'):
    if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
        model_url = _get_embedding_url('bert', model_dir_or_name.lower())
        model_dir = cached_path(model_url, name='embedding')
        # check whether it exists
    elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
        model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
    else:
        logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
        raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
    return str(model_dir)
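
A short usage sketch for the resolver above (whether `'en-base-uncased'` is present in `PRETRAINED_BERT_MODEL_DIR` in your fastNLP version, and the local directory shown, are assumptions):

    import os

    # Registered name: resolved via _get_embedding_url and cached_path on first use.
    bert_dir = _get_bert_dir('en-base-uncased')

    # Local directory: passed through unchanged, provided the path exists;
    # '~/models/my-bert' is a hypothetical path used only for illustration.
    local_dir = _get_bert_dir(os.path.expanduser('~/models/my-bert'))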
Example #5
    def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True,
                 init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
        """

        :param vocab: Vocabulary. 若该项为None则会读取所有的embedding。
        :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个
            以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。
            如果输入为None则使用embedding_dim的维度随机初始化一个embedding。
        :param int embedding_dim: 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name。
        :param bool requires_grad: 是否需要gradient. 默认为True
        :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法, 传入的方法应该接受一个tensor,并
            inplace地修改其值。
        :param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独
            为大写的词语开辟一个vector表示,则将lower设置为False。
        :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
        :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
        :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。
        :param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。
        :param dict kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中找到的词语使用normalize。
        """
        super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
        if embedding_dim > 0:
            model_dir_or_name = None

        # obtain the cache path
        if model_dir_or_name is None:
            assert embedding_dim >= 1, "The dimension of embedding should be at least 1."
            embedding_dim = int(embedding_dim)
            model_path = None
        elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
            model_url = _get_embedding_url('static', model_dir_or_name.lower())
            model_path = cached_path(model_url, name='embedding')
            # check whether it exists
        elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))):
            model_path = os.path.abspath(os.path.expanduser(model_dir_or_name))
        elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
            model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt')
        else:
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")

        # shrink the vocab according to min_freq
        truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq)
        if truncate_vocab:
            truncated_vocab = deepcopy(vocab)
            truncated_vocab.min_freq = min_freq
            truncated_vocab.word2idx = None
            if lower:  # when lowercasing, the frequencies of upper- and lower-case forms must be counted together
                lowered_word_count = defaultdict(int)
                for word, count in truncated_vocab.word_count.items():
                    lowered_word_count[word.lower()] += count
                for word in truncated_vocab.word_count.keys():
                    word_count = truncated_vocab.word_count[word]
                    if lowered_word_count[word.lower()] >= min_freq and word_count < min_freq:
                        truncated_vocab.add_word_lst([word] * (min_freq - word_count),
                                                     no_create_entry=truncated_vocab._is_word_no_create_entry(word))

            # apply the min_freq filter only to words that appear in train
            if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None:
                for word in truncated_vocab.word_count.keys():
                    if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word] < min_freq:
                        truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]),
                                                     no_create_entry=True)
            truncated_vocab.build_vocab()
            truncated_words_to_words = torch.arange(len(vocab)).long()
            for word, index in vocab:
                truncated_words_to_words[index] = truncated_vocab.to_index(word)
            logger.info(
                f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
            vocab = truncated_vocab

        self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
        # load the embedding
        if lower:
            lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
            for word, index in vocab:
                if vocab._is_word_no_create_entry(word):
                    lowered_vocab.add_word(word.lower(), no_create_entry=True)
                else:
                    lowered_vocab.add_word(word.lower())  # add the words that need their own entry first
            logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
                        f"unique lowered words.")
            if model_path:
                embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
            else:
                embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
            if lowered_vocab.unknown:
                unknown_idx = lowered_vocab.unknown_idx
            else:
                unknown_idx = embedding.size(0) - 1  # otherwise the last row is treated as unknown
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
            words_to_words = torch.full((len(vocab),), fill_value=unknown_idx).long()
            for word, index in vocab:
                if word not in lowered_vocab:
                    word = word.lower()
                    if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
                        continue  # if no entry needs to be created, it already defaults to unknown
                words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
            self.register_buffer('words_to_words', words_to_words)
            self._word_unk_index = lowered_vocab.unknown_idx  # replace the unknown index
        else:
            if model_path:
                embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
            else:
                embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
        if not self.only_norm_found_vector and normalize:
            embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)

        if truncate_vocab:
            for i in range(len(truncated_words_to_words)):
                index_in_truncated_vocab = truncated_words_to_words[i]
                truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
            del self.words_to_words
            self.register_buffer('words_to_words', truncated_words_to_words)
        self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                      padding_idx=vocab.padding_idx,
                                      max_norm=None, norm_type=2, scale_grad_by_freq=False,
                                      sparse=False, _weight=embedding)
        self._embed_size = self.embedding.weight.size(1)
        self.requires_grad = requires_grad
        self.dropout = MyDropout(dropout)
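
Note that this constructor appears to be a lightly customized copy of fastNLP's `StaticEmbedding` (it ends by installing a `MyDropout` module). A minimal usage sketch under that assumption, using the random-initialization branch so nothing has to be downloaded (`Vocabulary` comes from `fastNLP`; `StaticEmbedding` is the class defined above, or `fastNLP.embeddings.StaticEmbedding` in the stock library):

    import torch
    from fastNLP import Vocabulary

    # Build a tiny vocabulary; embedding_dim > 0 makes the constructor ignore
    # model_dir_or_name and randomly initialize the embedding, as shown above.
    vocab = Vocabulary()
    vocab.add_word_lst("the quick brown fox".split())
    vocab.build_vocab()

    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)
    words = torch.LongTensor([[vocab.to_index(w) for w in ("the", "fox")]])
    print(embed(words).shape)   # expected: torch.Size([1, 2, 50])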
Example #6
    def process(
        self,
        paths: Union[str, Dict[str, str]],
        dataset_name: str = None,
        to_lower=False,
        seq_len_type: str = None,
        bert_tokenizer: str = None,
        cut_text: int = None,
        get_index=True,
        auto_pad_length: int = None,
        auto_pad_token: str = '<pad>',
        set_input: Union[list, str, bool] = True,
        set_target: Union[list, str, bool] = True,
        concat: Union[str, list, bool] = None,
    ) -> DataBundle:
        """
        :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
            则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
            对应的全路径文件名。
        :param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义
            这个数据集的名字,如果不定义则默认为train。
        :param bool to_lower: 是否将文本自动转为小写。默认值为False。
        :param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` :
            提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和
            attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len
        :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径
        :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。
        :param bool get_index: 是否需要根据词表将文本转为index
        :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad
        :param str auto_pad_token: 自动pad的内容
        :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False
            则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input,
            于此同时其他field不会被设置为input。默认值为True。
        :param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。
        :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。
            如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果
            传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]'].
        :return:
        """
        if isinstance(set_input, str):
            set_input = [set_input]
        if isinstance(set_target, str):
            set_target = [set_target]
        if isinstance(set_input, bool):
            auto_set_input = set_input
        else:
            auto_set_input = False
        if isinstance(set_target, bool):
            auto_set_target = set_target
        else:
            auto_set_target = False
        if isinstance(paths, str):
            if os.path.isdir(paths):
                path = {
                    n: os.path.join(paths, self.paths[n])
                    for n in self.paths.keys()
                }
            else:
                path = {
                    dataset_name if dataset_name is not None else 'train':
                    paths
                }
        else:
            path = paths

        data_info = DataBundle()
        for data_name in path.keys():
            data_info.datasets[data_name] = self._load(path[data_name])

        for data_name, data_set in data_info.datasets.items():
            if auto_set_input:
                data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
            if auto_set_target:
                if Const.TARGET in data_set.get_field_names():
                    data_set.set_target(Const.TARGET)

        if to_lower:
            for data_name, data_set in data_info.datasets.items():
                data_set.apply(
                    lambda x: [w.lower() for w in x[Const.INPUTS(0)]],
                    new_field_name=Const.INPUTS(0),
                    is_input=auto_set_input)
                data_set.apply(
                    lambda x: [w.lower() for w in x[Const.INPUTS(1)]],
                    new_field_name=Const.INPUTS(1),
                    is_input=auto_set_input)

        if bert_tokenizer is not None:
            if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
                PRETRAIN_URL = _get_base_url('bert')
                model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
                model_url = PRETRAIN_URL + model_name
                model_dir = cached_path(model_url)
                # check whether it exists
            elif os.path.isdir(bert_tokenizer):
                model_dir = bert_tokenizer
            else:
                raise ValueError(
                    f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")

            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
            with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
                lines = f.readlines()
            lines = [line.strip() for line in lines]
            words_vocab.add_word_lst(lines)
            words_vocab.build_vocab()

            tokenizer = BertTokenizer.from_pretrained(model_dir)

            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(
                            lambda x: tokenizer.tokenize(' '.join(x[fields])),
                            new_field_name=fields,
                            is_input=auto_set_input)

        if isinstance(concat, bool):
            concat = 'default' if concat else None
        if concat is not None:
            if isinstance(concat, str):
                CONCAT_MAP = {
                    'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
                    'default': ['', '<sep>', '', '']
                }
                if concat.lower() in CONCAT_MAP:
                    concat = CONCAT_MAP[concat]
                else:
                    concat = 4 * [concat]
            assert len(concat) == 4, \
                f'Please provide a list of 4 symbols: one for the beginning of the first sentence, ' \
                f'one for the end of the first sentence, one for the beginning of the second sentence, ' \
                f'and one for the end of the second sentence. Your input is {concat}'

            for data_name, data_set in data_info.datasets.items():
                data_set.apply(
                    lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[
                        1]] + [concat[2]] + x[Const.INPUTS(1)] + [concat[3]],
                    new_field_name=Const.INPUT)
                data_set.apply(
                    lambda x: [w for w in x[Const.INPUT] if len(w) > 0],
                    new_field_name=Const.INPUT,
                    is_input=auto_set_input)

        if seq_len_type is not None:
            if seq_len_type == 'seq_len':  #
                for data_name, data_set in data_info.datasets.items():
                    for fields in data_set.get_field_names():
                        if Const.INPUT in fields:
                            data_set.apply(lambda x: len(x[fields]),
                                           new_field_name=fields.replace(
                                               Const.INPUT, Const.INPUT_LEN),
                                           is_input=auto_set_input)
            elif seq_len_type == 'mask':
                for data_name, data_set in data_info.datasets.items():
                    for fields in data_set.get_field_names():
                        if Const.INPUT in fields:
                            data_set.apply(lambda x: [1] * len(x[fields]),
                                           new_field_name=fields.replace(
                                               Const.INPUT, Const.INPUT_LEN),
                                           is_input=auto_set_input)
            elif seq_len_type == 'bert':
                for data_name, data_set in data_info.datasets.items():
                    if Const.INPUT not in data_set.get_field_names():
                        raise KeyError(
                            f'Field ``{Const.INPUT}`` not in {data_name} data set: '
                            f'got {data_set.get_field_names()}')
                    data_set.apply(lambda x: [0] *
                                   (len(x[Const.INPUTS(0)]) + 2) + [1] *
                                   (len(x[Const.INPUTS(1)]) + 1),
                                   new_field_name=Const.INPUT_LENS(0),
                                   is_input=auto_set_input)
                    data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
                                   new_field_name=Const.INPUT_LENS(1),
                                   is_input=auto_set_input)

        if auto_pad_length is not None:
            cut_text = min(
                auto_pad_length,
                cut_text if cut_text is not None else auto_pad_length)

        if cut_text is not None:
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if (Const.INPUT
                            in fields) or ((Const.INPUT_LEN in fields) and
                                           (seq_len_type != 'seq_len')):
                        data_set.apply(lambda x: x[fields][:cut_text],
                                       new_field_name=fields,
                                       is_input=auto_set_input)

        data_set_list = [d for n, d in data_info.datasets.items()]
        assert len(data_set_list) > 0, f'There are NO data sets in data info!'

        if bert_tokenizer is None:
            words_vocab = Vocabulary(padding=auto_pad_token)
            words_vocab = words_vocab.from_dataset(
                *[d for n, d in data_info.datasets.items() if 'train' in n],
                field_name=[
                    n for n in data_set_list[0].get_field_names()
                    if (Const.INPUT in n)
                ],
                no_create_entry_dataset=[
                    d for n, d in data_info.datasets.items()
                    if 'train' not in n
                ])
        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab = target_vocab.from_dataset(
            *[d for n, d in data_info.datasets.items() if 'train' in n],
            field_name=Const.TARGET)
        data_info.vocabs = {
            Const.INPUT: words_vocab,
            Const.TARGET: target_vocab
        }

        if get_index:
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(
                            lambda x:
                            [words_vocab.to_index(w) for w in x[fields]],
                            new_field_name=fields,
                            is_input=auto_set_input)

                if Const.TARGET in data_set.get_field_names():
                    data_set.apply(
                        lambda x: target_vocab.to_index(x[Const.TARGET]),
                        new_field_name=Const.TARGET,
                        is_input=auto_set_input,
                        is_target=auto_set_target)

        if auto_pad_length is not None:
            if seq_len_type == 'seq_len':
                raise RuntimeError(
                    f'the sequence will be padded with the length {auto_pad_length}, '
                    f'so the seq_len_type cannot be `{seq_len_type}`!')
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(
                            lambda x: x[fields] +
                            [words_vocab.to_index(words_vocab.padding)] *
                            (auto_pad_length - len(x[fields])),
                            new_field_name=fields,
                            is_input=auto_set_input)
                    elif (Const.INPUT_LEN
                          in fields) and (seq_len_type != 'seq_len'):
                        data_set.apply(lambda x: x[fields] + [0] *
                                       (auto_pad_length - len(x[fields])),
                                       new_field_name=fields,
                                       is_input=auto_set_input)

        for data_name, data_set in data_info.datasets.items():
            if isinstance(set_input, list):
                data_set.set_input(*[
                    inputs for inputs in set_input
                    if inputs in data_set.get_field_names()
                ])
            if isinstance(set_target, list):
                data_set.set_target(*[
                    target for target in set_target
                    if target in data_set.get_field_names()
                ])

        return data_info
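
A hedged sketch of driving this pipeline end to end (the loader class name and file names below are placeholders for whichever concrete subclass defines `process` and `_load` in your code base):

    # Usage sketch -- MatchingLoaderSubclass and the .tsv names are hypothetical
    # placeholders; substitute the concrete loader and paths from your project.
    loader = MatchingLoaderSubclass()
    data_bundle = loader.process(
        paths={'train': 'train.tsv', 'dev': 'dev.tsv'},
        to_lower=True,
        seq_len_type='seq_len',   # adds a *_len field for every input field
        concat=None,              # keep the two sentences as separate fields
    )
    print(data_bundle.datasets.keys())
    print(data_bundle.vocabs.keys())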