# Module-level imports assumed by the snippets below; load_vocab_bpe and
# remove_seow are project-local helpers defined elsewhere in the repo.
import json
import logging
import random

import numpy as np
from bpe import Encoder
from keras.preprocessing.sequence import pad_sequences
from nltk.tokenize import WhitespaceTokenizer
from tqdm import tqdm
from typing import List, Tuple, Union

logger = logging.getLogger(__name__)


def __init__(self, vocab_file, do_lower_case=True, simple_subword=False):
    # do_lower_case is accepted for API compatibility but unused here.
    self.vocab, self.inv_vocab = load_vocab_bpe(vocab_file)
    self.vocab_size = max(self.vocab.values()) + 1
    self.simple_subword = simple_subword
    # The serialized BPE merge table lives next to the vocab file.
    with open(vocab_file + '.dict', encoding='utf-8') as fin:
        bpe_dict = json.load(fin)
    self.encoder = Encoder.from_dict(bpe_dict)
    # Split on whitespace only; the corpus is pre-tokenized.
    self.encoder.word_tokenizer = WhitespaceTokenizer().tokenize
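# Illustrative sketch (toy corpus, not from the repo) of why the
# word_tokenizer override and the remove_seow helper used below matter: the
# bpe package's Encoder wraps the subword pieces of out-of-vocabulary words
# in __sow/__eow markers when tokenizing.
if __name__ == "__main__":
    toy = Encoder(vocab_size=100, pct_bpe=0.88)
    toy.fit(["low lower lowest", "new newer newest"])
    # "slowest" was never seen as a whole word, so it goes through the
    # subword path and comes back as something like
    # ['__sow', 's', 'low', 'est', '__eow'] (exact pieces depend on the fit).
    print(toy.tokenize("slowest"))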
def encode_data(data: List[Union[List[str], str]], vocabulary: dict,
                labels: Union[np.ndarray, List[int]] = None, max_length: int = None,
                use_bpe: bool = False, for_classification: bool = False
                ) -> Union[np.ndarray, Tuple[np.ndarray, List[np.ndarray]]]:
    """Encode tokenized tweets (or raw strings when use_bpe=True) as padded id
    matrices; when labels are given, also build shifted language-model targets."""
    if not use_bpe and not isinstance(data[0], list):
        logger.error("Tweets need to be tokenized for encoding.")
        raise ValueError("Tweets need to be tokenized for encoding.")
    max_length_given = max_length is not None
    if not max_length_given:
        max_length = 0
    encoder = None
    if use_bpe:
        # TOKENIZER is a module-level tokenizer instance defined elsewhere in the repo.
        encoder = Encoder.from_dict(vocabulary)
        encoder.word_tokenizer = lambda x: TOKENIZER.tokenize(x)
        encoder.custom_tokenizer = True
    encoded_data = []
    logger.info("Encoding data...")
    for tweet in tqdm(data):
        current_tweet = []
        if use_bpe:
            current_tweet.extend(list(next(encoder.transform([tweet]))))
        else:
            for token_idx, token in enumerate(tweet):
                if max_length_given and token_idx == max_length:
                    break
                # Membership test rather than a truthiness check, so a token
                # legitimately mapped to id 0 is not mistaken for OOV.
                if token in vocabulary:
                    current_tweet.append(vocabulary[token])
                else:
                    current_tweet.append(vocabulary.get("<unk>"))
        if for_classification:
            # Wrap the sequence in <start> ... <extract> for the classifier head.
            special_vocab = encoder.word_vocab if use_bpe else vocabulary
            current_tweet.insert(0, special_vocab.get("<start>"))
            current_tweet.append(special_vocab.get("<extract>"))
        encoded_data.append(current_tweet)
        if not max_length_given:
            max_length = max(max_length, len(current_tweet))
    logger.info("Encoded data.")
    if not max_length_given and not for_classification:
        # Add two to account for the <start> and <extract> tokens added later.
        max_length += 2
    pad_value = encoder.word_vocab[encoder.PAD] if use_bpe else 0
    encoded_data = pad_sequences(encoded_data, maxlen=max_length, padding="post",
                                 value=pad_value)
    if labels is not None:
        # LM targets are the inputs shifted left by one, padded at the end.
        encoded_targets = np.concatenate(
            [encoded_data[:, 1:], np.full((len(encoded_data), 1), pad_value)], axis=1)
        encoded_targets = np.reshape(encoded_targets, encoded_targets.shape + (1,))
        return encoded_data, [encoded_targets, np.array(labels)]
    return encoded_data
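# Minimal usage sketch for encode_data on pre-tokenized tweets. The vocabulary
# and tweets are made-up toy values; id 0 is reserved for padding, and the
# special tokens <unk>, <start>, <extract> must exist in a real vocabulary.
if __name__ == "__main__":
    toy_vocab = {"<unk>": 1, "<start>": 2, "<extract>": 3, "good": 4, "day": 5}
    toy_tweets = [["good", "day"], ["what", "a", "good", "day"]]
    x, (lm_targets, class_labels) = encode_data(
        toy_tweets, toy_vocab, labels=[1, 0], for_classification=True)
    print(x.shape)           # (2, 6): longest tweet plus <start> and <extract>
    print(lm_targets.shape)  # (2, 6, 1): inputs shifted left by one
    print(class_labels)      # [1 0]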
def __init__(self, data_path="data/mt_corpus_ts.txt", vocab_path='data/vocab_bpe.txt',
             timesteps_max=100, batch_size=256, word_dropout_ratio=0.75,
             simple_subword=False):
    self.data_path = data_path
    self.simple_subword = simple_subword
    self.word_dropout_ratio = word_dropout_ratio
    self.timesteps_max = timesteps_max
    # Load the vocab and the serialized BPE encoder.
    self.char2idx, self.idx2char = load_vocab_bpe(vocab_path)
    with open(vocab_path + '.dict', encoding='utf-8') as fin:
        bpe_dict = json.load(fin)
    self.encoder = Encoder.from_dict(bpe_dict)
    self.encoder.word_tokenizer = WhitespaceTokenizer().tokenize
    self.batch_size = batch_size
def read_text_data_bpe(num_samples=3000, data_path="data/mt_corpus_ts.txt",
                       vocab_path='data/vocab_bpe.txt', word_dropout_ratio=0.75,
                       simple_subword=False):
    """Read up to num_samples sentences, BPE-tokenize them, and build the id
    matrices for denoising seq2seq training.

    :param num_samples: number of sentences to keep
    :param data_path: corpus path, one sentence per line (first line is a header)
    :param vocab_path: BPE vocab path (with a companion .dict merge table)
    :param word_dropout_ratio: probability of corrupting an input token
    :param simple_subword: passed through to remove_seow
    :return: (max_encoder_seq_length, char2idx, idx2char,
              encoder_input_data, decoder_input_data, decoder_output_data)
    """
    rng = random.Random(88)
    timesteps_max = 100
    input_texts = []

    # Load the vocab and the serialized BPE encoder.
    char2idx, idx2char = load_vocab_bpe(vocab_path)
    with open(vocab_path + '.dict', encoding='utf-8') as fin:
        bpe_dict = json.load(fin)
    encoder = Encoder.from_dict(bpe_dict)
    encoder.word_tokenizer = WhitespaceTokenizer().tokenize

    lines = []
    print('read data from ', data_path)
    line_num = 0
    with open(data_path, encoding='utf-8') as fin:
        for line in fin:
            line_num += 1
            if line_num == 1:  # skip the header line
                continue
            if line_num % 100000 == 0:
                print(line_num)
            if line_num > num_samples + 200:
                break
            zh = line
            terms = zh.split()
            # Cheap whitespace-token pre-filter before the slower BPE pass;
            # keep two slots free for the <sos>/<eos> markers.
            if len(terms) <= timesteps_max - 2:
                terms = encoder.tokenize(zh)
                # Strip (or merge) the __sow/__eow subword markers.
                terms = remove_seow(terms, encoder, simple_subword)
                if len(terms) <= timesteps_max - 2:
                    lines.append(terms)

    for line in lines[:num_samples]:
        input_text = line
        input_text.append("<eos>")
        input_texts.append(input_text)

    num_encoder_tokens = max(idx2char.keys()) + 1
    max_encoder_seq_length = timesteps_max  # fixed length, not the observed maximum
    print("Number of samples:", len(input_texts))
    print("Number of unique input tokens:", num_encoder_tokens)
    print("Max sequence length for inputs:", max_encoder_seq_length)

    encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length), dtype="int32")
    decoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length), dtype="int32")
    decoder_output_data = np.zeros((len(input_texts), max_encoder_seq_length), dtype="int32")

    for i, input_text in enumerate(input_texts):
        decoder_input_data[i, 0] = char2idx["<sos>"]
        for t, char in enumerate(input_text):
            idx = char2idx[char] if char in char2idx else char2idx["<unk>"]
            idx_mask = idx
            # Word dropout: corrupt the inputs while keeping the clean token
            # as the decoder target.
            if rng.random() < word_dropout_ratio:
                # TODO: add a dedicated <mask> token instead of reusing <unk>
                if rng.random() < 0.9:
                    idx_mask = char2idx["<unk>"]
                else:
                    # 10% of the time, replace with a random word
                    idx_mask = rng.randint(0, num_encoder_tokens - 1)
            encoder_input_data[i, t] = idx_mask
            decoder_output_data[i, t] = idx
            # Decoder input is the corrupted sequence shifted right by one.
            decoder_input_data[i, t + 1] = idx_mask

    return (max_encoder_seq_length, char2idx, idx2char,
            encoder_input_data, decoder_input_data, decoder_output_data)
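# Hypothetical driver for read_text_data_bpe, assuming the default corpus and
# BPE vocab files from this repo exist on disk; the printed shapes follow from
# the timesteps_max=100 hard-coded above.
if __name__ == "__main__":
    (seq_len, char2idx, idx2char,
     enc_in, dec_in, dec_out) = read_text_data_bpe(num_samples=1000)
    print(enc_in.shape, dec_in.shape, dec_out.shape)  # each (n_kept, 100)
    # The decoder input starts with <sos> and carries the same word-dropout
    # corruption as the encoder input; the decoder output is the clean target.
    assert dec_in[0, 0] == char2idx["<sos>"]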