import os
import pickle
import tempfile
from argparse import Namespace

from tokenizers import SentencePieceBPETokenizer

# Vocabulary, cmd_args, get_special_tokens, create_bpe_training_file, and
# Preset are project-local and assumed to be defined or imported elsewhere.


def main():
    args = cmd_args()
    outdir = args.o if args.o else os.path.dirname(args.i)
    target_special_tokens, subtoken_special_tokens = get_special_tokens(args.preset)
    with tempfile.TemporaryDirectory() as tmp_dir:
        targets_file = os.path.join(tmp_dir, "labels.txt")
        subtokens_file = os.path.join(tmp_dir, "subtokens.txt")
        print("Creating training files for BPE")
        create_bpe_training_file(args.i, targets_file, subtokens_file)
        if args.preset == Preset.variable:
            print("Variable preset")
        subtoken_tokenizer = SentencePieceBPETokenizer()
        target_tokenizer = SentencePieceBPETokenizer()
        subtoken_tokenizer.add_special_tokens(subtoken_special_tokens)
        target_tokenizer.add_special_tokens(target_special_tokens)
        print("Training target tokenizer")
        target_tokenizer.train(files=[targets_file], vocab_size=args.target_vocab)
        print("Training subtoken tokenizer")
        subtoken_tokenizer.train(files=[subtokens_file], vocab_size=args.subtoken_vocab)
        target_tokenizer.save(outdir, "target.bpe")
        subtoken_tokenizer.save(outdir, "subtoken.bpe")
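# Example invocation (a sketch: the script name and flag spellings here are
# assumptions; the real ones are defined by cmd_args() elsewhere in this
# project):
#
#   python build_bpe_vocab.py -i data/train.c2s -o vocab/ --preset variable
#
# args.target_vocab and args.subtoken_vocab give the BPE vocab sizes passed to
# train() above; the trained vocab/merges files are written into outdir.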
class BPEVocabulary(Vocabulary):
    """
    Represents a SentencePiece vocabulary for c2s.
    """

    def __init__(self, args: Namespace):
        super().__init__()
        self.target_encoder = SentencePieceBPETokenizer(
            args.target_vocab, args.target_merges)
        self.subtoken_encoder = SentencePieceBPETokenizer(
            args.subtoken_vocab, args.subtoken_merges)
        # self.target_encoder.add_special_tokens(
        #     [self.EOS_TOKEN, self.SOS_TOKEN, self.PAD_TOKEN]
        # )
        # self.subtoken_encoder.add_special_tokens([self.EOS_TOKEN, self.PAD_TOKEN])
        with open(args.node_dict, "rb") as f:
            self.node_to_index = pickle.load(f)
        self.index_to_node = {v: k for k, v in self.node_to_index.items()}

    def target_vocab_size(self):
        # +4 appears to reserve room for special tokens added at runtime
        # (SOS/EOS/PAD plus the built-in UNK), which get_vocab_size() does
        # not count.
        return self.target_encoder.get_vocab_size() + 4

    def node_vocab_size(self):
        # +2 appears to reserve room for the UNK and PAD node entries.
        return len(self.node_to_index) + 2

    def terminal_vocab_size(self):
        return self.subtoken_encoder.get_vocab_size() + 4

    def add_special_target_token(self, token: str):
        self.target_encoder.add_special_tokens([token])

    def add_special_terminal_token(self, token: str):
        self.subtoken_encoder.add_special_tokens([token])

    def encode_node(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.node_to_index.get(token_or_tokens,
                                          self.node_to_index[self.UNK_TOKEN])
        return list(map(self.encode_node, token_or_tokens))

    def decode_node(self, index_or_indices):
        if isinstance(index_or_indices, int):
            return self.index_to_node[index_or_indices]
        return list(map(self.decode_node, index_or_indices))

    def encode_target(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.target_encoder.token_to_id(token_or_tokens)
        return self.target_encoder.encode(" ".join(token_or_tokens)).ids

    def decode_target(self, index_or_indices):
        if isinstance(index_or_indices, int):
            return self.target_encoder.id_to_token(index_or_indices)
        return self.target_encoder.decode(index_or_indices)

    def encode_terminal(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.subtoken_encoder.token_to_id(token_or_tokens)
        return self.subtoken_encoder.encode(" ".join(token_or_tokens)).ids

    def decode_terminal(self, index_or_indices):
        # Note: these previously referenced self.terminal_encoder, which is
        # never defined; the attribute is self.subtoken_encoder.
        if isinstance(index_or_indices, int):
            return self.subtoken_encoder.id_to_token(index_or_indices)
        return self.subtoken_encoder.decode(index_or_indices)
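# A minimal usage sketch, not part of the original source. The Namespace fields
# mirror exactly what BPEVocabulary.__init__ reads; the file paths are
# assumptions (the actual names of the vocab/merges files written by main()
# depend on the tokenizers version in use).
def _example_usage():
    args = Namespace(
        target_vocab="vocab/target.bpe-vocab.json",
        target_merges="vocab/target.bpe-merges.txt",
        subtoken_vocab="vocab/subtoken.bpe-vocab.json",
        subtoken_merges="vocab/subtoken.bpe-merges.txt",
        node_dict="vocab/node_dict.pkl",
    )
    vocab = BPEVocabulary(args)
    ids = vocab.encode_target(["get", "file", "name"])  # BPE ids for a target name
    print(vocab.decode_target(ids))                     # decodes back to a string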