示例#1
0
    def process_line(self, line, vocab, max_len, need_raw=False):
        """Tokenize one sentence, encode it to ids and pad/truncate it.

        Args:
            line: the sentence text to process.
            vocab: vocabulary object whose ``encode`` maps text to ids.
            max_len: length of the returned id list; shorter sequences are
                extended with the pad id, longer ones are truncated.
            need_raw: when True, also return the processed tokens wrapped in
                ``constant.SYMBOL_START`` / ``constant.SYMBOL_END``.

        Returns:
            ``(words, words_raw)`` where ``words`` is the id list (length
            ``max_len``) and ``words_raw`` is the raw token list or ``None``.

        Raises:
            Exception: if ``self.model_config.tokenizer`` is neither
                ``'split'`` nor ``'nltk'``.
        """
        tokenizer = self.model_config.tokenizer
        if tokenizer == 'split':
            words = line.split()
        elif tokenizer == 'nltk':
            words = word_tokenize(line)
        else:
            raise Exception('Unknown tokenizer.')

        words = [Vocab.process_word(word, self.model_config)
                 for word in words]
        if need_raw:
            words_raw = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
        else:
            words_raw = None

        if self.model_config.subword_vocab_size > 0:
            # Subword mode: the whole sentence, boundary symbols included,
            # is encoded in a single call on the joined string.
            words = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
            words = vocab.encode(' '.join(words))
            pad_id = vocab.encode(constant.SYMBOL_PAD)
        else:
            # Word mode: encode token by token. The boundary symbols are
            # encoded with the same ``vocab`` as the tokens so that every id
            # in the sequence comes from one vocabulary.
            words = [vocab.encode(word) for word in words]
            words = ([vocab.encode(constant.SYMBOL_START)] + words +
                     [vocab.encode(constant.SYMBOL_END)])
            # Wrapped in a list so ``num_pad * pad_id`` repeats it below.
            pad_id = [vocab.encode(constant.SYMBOL_PAD)]

        if len(words) < max_len:
            num_pad = max_len - len(words)
            words.extend(num_pad * pad_id)
        else:
            words = words[:max_len]

        return words, words_raw
def process_line(line,
                 vocab,
                 max_len,
                 model_config,
                 need_raw=False,
                 lower_case=True,
                 base_line=None):
    """Tokenize one sentence, encode it to ids, pad/truncate it.

    Args:
        line: the sentence, as ``str`` or UTF-8 encoded ``bytes``.
        vocab: vocabulary object whose ``encode`` maps text to ids.
        max_len: length of the returned id list; shorter sequences are
            extended with the pad id, longer ones are truncated.
        model_config: config read for ``tokenizer``, ``subword_vocab_size``,
            ``bert_mode`` and ``seg_mode``.
        need_raw: when True, also return the processed tokens wrapped in
            ``constant.SYMBOL_START`` / ``constant.SYMBOL_END``.
        lower_case: lower-case the text before tokenizing.
        base_line: forwarded to ``get_segment_copy_idx`` in ``'cp'`` mode.

    Returns:
        ``(words, words_raw, obj)``: the id list (length ``max_len``), the raw
        token list or ``None``, and a dict that holds ``'segment_idxs'`` when
        a subword vocab is used and ``seg_mode`` contains ``'seg'`` or ``'cp'``.

    Raises:
        Exception: if ``model_config.tokenizer`` is neither ``'split'`` nor
            ``'nltk'``.
    """
    # Decode before lower-casing: ``bytes.lower`` only folds ASCII letters,
    # so non-ASCII uppercase characters would otherwise survive.
    if isinstance(line, bytes):
        line = str(line, 'utf-8')
    if lower_case:
        line = line.lower()

    if model_config.tokenizer == 'split':
        words = line.split()
    elif model_config.tokenizer == 'nltk':
        words = word_tokenize(line)
    else:
        raise Exception('Unknown tokenizer.')

    words = [Vocab.process_word(word, model_config) for word in words]
    if need_raw:
        words_raw = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
    else:
        words_raw = None

    # Subword vocabularies and BERT tokenization encode the whole sentence
    # in one call; otherwise tokens are encoded one by one.
    encode_sentence = (model_config.subword_vocab_size > 0
                       or 'bert_token' in model_config.bert_mode)
    if encode_sentence:
        words = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
        words = vocab.encode(' '.join(words))
        pad_id = vocab.encode(constant.SYMBOL_PAD)
    else:
        words = [vocab.encode(word) for word in words]
        words = ([vocab.encode(constant.SYMBOL_START)] + words +
                 [vocab.encode(constant.SYMBOL_END)])
        # Wrapped in a list so ``num_pad * pad_id`` repeats it below.
        pad_id = [vocab.encode(constant.SYMBOL_PAD)]

    if len(words) < max_len:
        num_pad = max_len - len(words)
        words.extend(num_pad * pad_id)
    else:
        words = words[:max_len]

    obj = {}
    if model_config.subword_vocab_size and 'seg' in model_config.seg_mode:
        obj['segment_idxs'] = get_segment_idx(words, vocab)
    elif model_config.subword_vocab_size and 'cp' in model_config.seg_mode:
        # NOTE(review): hard-coded, machine-specific path, and
        # ``populate_freq`` runs once per line. It presumably fills the
        # module-level ``freq`` read on the next line — confirm that it
        # caches, and consider making the path configurable.
        populate_freq('/zfs1/hdaqing/saz31/dataset/vocab/all.vocab')
        obj['segment_idxs'] = get_segment_copy_idx(words,
                                                   freq,
                                                   vocab,
                                                   base_line=base_line)

    return words, words_raw, obj