def main():
    """Command-line entry point: write a vocabulary file, one token per line.

    Two modes are selected with ``-file_type``:

    * ``text`` (default): every file given to ``-file`` is read through
      ``read_files_batch``; tokens are counted and written to ``-out_file``
      in descending order of frequency (ties keep first-seen order).
    * ``field``: ``-file`` must name exactly one torch file; the vocabulary
      stored under the key given by ``-side`` is written to ``-out_file`` in
      its ``itos`` order.

    The output file is UTF-8 encoded in both modes.

    Raises:
        SystemExit: on invalid command-line arguments, including a missing
            ``-side`` when ``-file_type`` is ``field``.
        ValueError: if more than one ``-file`` is given in ``field`` mode.
    """
    from collections import Counter

    parser = argparse.ArgumentParser()
    # Not marked as required: the documented default ('text') can only take
    # effect if the flag may be omitted.
    parser.add_argument('-file_type', default='text',
                        choices=['text', 'field'],
                        help="""Options for vocabulary creation.
                               The default is 'text' where the user passes
                               a corpus or a list of corpora files for which
                               they want to create a vocabulary from.
                               If choosing the option 'field', we assume
                               the file passed is a torch file created during
                               the preprocessing stage of an already
                               preprocessed corpus. The vocabulary file created
                               will just be the vocabulary inside the field
                               corresponding to the argument 'side'.""")
    parser.add_argument("-file", type=str, nargs="+", required=True)
    parser.add_argument("-out_file", type=str, required=True)
    parser.add_argument("-side", type=str)

    opt = parser.parse_args()

    # Fail early with a usage message instead of an opaque KeyError(None)
    # when the field lookup below would be attempted without a side.
    if opt.file_type == 'field' and opt.side is None:
        parser.error("-side is required when -file_type='field'.")

    if opt.file_type == 'text':
        print("Reading input file...")
        vocabulary = Counter()
        for batch in read_files_batch(opt.file):
            for sentence in batch:
                vocabulary.update(sentence)

        print("Writing vocabulary file...")
        # Explicit UTF-8 so the result does not depend on the locale and
        # matches the encoding used by the 'field' branch.
        with open(opt.out_file, "w", encoding="utf-8") as f:
            # most_common() sorts by descending count and keeps first-seen
            # order among equal counts.
            for w, _count in vocabulary.most_common():
                f.write("{0}\n".format(w))
    else:
        import torch
        from onmt.inputters.inputter import _old_style_vocab

        print("Reading input file...")
        if not len(opt.file) == 1:
            raise ValueError("If using -file_type='field', only pass one "
                             "argument for -file.")
        # NOTE: torch.load unpickles the file, so only load files that come
        # from a trusted source.
        vocabs = torch.load(opt.file[0])
        voc = dict(vocabs)[opt.side]
        if _old_style_vocab(voc):
            word_list = voc.itos
        else:
            try:
                word_list = voc[0][1].base_field.vocab.itos
            except AttributeError:
                # The field has no base_field attribute; read its vocab
                # directly.
                word_list = voc[0][1].vocab.itos

        print("Writing vocabulary file...")
        with open(opt.out_file, "wb") as f:
            for w in word_list:
                f.write(u"{0}\n".format(w).encode("utf-8"))
def get_vocabs(dict_path):
    """Load the 'src' and 'tgt' vocabularies stored in a torch file.

    The object loaded from ``dict_path`` is either an old-style collection
    of ``(name, vocab)`` pairs (as detected by ``_old_style_vocab``) or a
    mapping from side name to a field object.

    Args:
        dict_path: path handed to ``torch.load``.

    Returns:
        A ``(enc_vocab, dec_vocab)`` tuple for the 'src' and 'tgt' sides.
    """
    loaded = torch.load(dict_path)

    def _vocab_for(name):
        # Old-style files hold (name, vocab) pairs: take the first match,
        # or None when the name is absent.
        if _old_style_vocab(loaded):
            for key, value in loaded:
                if key == name:
                    return value
            return None
        # Otherwise the vocab sits on the field's base_field when there is
        # one, else on the field itself.
        try:
            return loaded[name].base_field.vocab
        except AttributeError:
            return loaded[name].vocab

    enc_vocab = _vocab_for('src')
    dec_vocab = _vocab_for('tgt')

    logger.info("From: %s" % dict_path)
    logger.info("\t* source vocab: %d words" % len(enc_vocab))
    logger.info("\t* target vocab: %d words" % len(dec_vocab))

    return enc_vocab, dec_vocab
def get_vocabs(dict_path):
    """Load the question/answer and target vocabularies from a torch file.

    The vocabularies for the 'ques', 'ans' and 'tgt' sides are read from
    the object loaded from ``dict_path``. The 'ques' and 'ans' vocabularies
    must compare equal; only one of them is returned.

    Args:
        dict_path: path handed to ``torch.load``.

    Returns:
        A ``(ques_vocab, dec_vocab)`` tuple, where ``dec_vocab`` is the
        'tgt' vocabulary.

    Raises:
        AssertionError: if the 'ques' and 'ans' vocabularies differ.
    """
    fields = torch.load(dict_path)

    # The argument is the same for every side, so evaluate the check once.
    old_style = _old_style_vocab(fields)

    vocs = []
    for side in ['ques', 'ans', 'tgt']:
        if old_style:
            # Old-style files hold (name, vocab) pairs.
            vocab = next((v for n, v in fields if n == side), None)
        else:
            try:
                vocab = fields[side].base_field.vocab
            except AttributeError:
                # The field has no base_field attribute; read its vocab
                # directly.
                vocab = fields[side].vocab
        vocs.append(vocab)
    ques_vocab, ans_vocab, dec_vocab = vocs

    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is kept so callers
    # see the same error as before.
    if not ques_vocab == ans_vocab:
        raise AssertionError(
            "'ques' and 'ans' vocabularies in %s differ" % dict_path)

    logger.info("From: %s" % dict_path)
    logger.info("\t* ques vocab: %d words" % len(ques_vocab))
    logger.info("\t* ans vocab: %d words" % len(ans_vocab))
    logger.info("\t* target vocab: %d words" % len(dec_vocab))

    return ques_vocab, dec_vocab