# Shared imports for the snippets below; pretrain and utils are the
# stanza.models.common modules, CoNLL is stanza's conllu document reader.
import argparse
import os
import sys
import tempfile

import pytest

from stanza.models.common import pretrain, utils
from stanza.utils.conll import CoNLL


def main():
    """Load a pretrain from a vector file and report its vocab size."""
    filename = sys.argv[1]
    vec_filename = sys.argv[2]
    max_vocab = int(sys.argv[3])
    pt = pretrain.Pretrain(filename, vec_filename, max_vocab)
    print("Pretrain is of size {}".format(len(pt.vocab)))

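# A standard entry-point guard would let main() run as a script; a minimal
# sketch, assuming main() is defined at module level as above:
if __name__ == '__main__':
    main()
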
def load_pretrain(args):
    """
    Load a pretrain if the arguments ask for one, otherwise return None
    """
    pt = None
    if args['pretrain']:
        # prefer an existing saved pretrain; otherwise build one from raw word vectors
        pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
        if os.path.exists(pretrain_file):
            vec_file = None
        else:
            vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
        pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
    return pt

def test_resave_pretrain():
    """
    Test saving a pretrain and then loading from the existing file
    """
    test_pt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".pt", delete=False)
    try:
        test_pt_file.close()
        # note that this tests the ability to save a pretrain and the
        # ability to fall back when the existing pretrain isn't working
        pt = pretrain.Pretrain(filename=test_pt_file.name,
                               vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz')
        check_pretrain(pt)
        pt2 = pretrain.Pretrain(filename=test_pt_file.name,
                                vec_filename='unban_mox_opal')
        check_pretrain(pt2)
    finally:
        os.unlink(test_pt_file.name)

def load_pretrain(args):
    """
    Loads a pretrain based on the paths in the arguments
    """
    pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
    if os.path.exists(pretrain_file):
        vec_file = None
    else:
        vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
    pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
    return pt

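# A hedged usage sketch for load_pretrain: the dict keys are the ones the
# function reads, but every value here is a hypothetical placeholder rather
# than a real default ('pretrain' is only consulted by the guarded variant
# further up).
args = {
    'pretrain': True,
    'wordvec_pretrain_file': None,
    'save_dir': 'saved_models/pos',
    'shorthand': 'en_ewt',
    'lang': 'en',
    'wordvec_file': None,
    'wordvec_dir': 'extern_data/wordvec',
    'pretrain_max_vocab': 250000,
}
pt = load_pretrain(args)
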
def test_whitespace():
    """
    Test reading a pretrain with an ascii space in it

    The vocab word with a space in it should have the correct number of
    dimensions read, with the space converted to nbsp
    """
    test_txt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".txt", delete=False)
    try:
        test_txt_file.write(SPACE_PRETRAIN.encode())
        test_txt_file.close()
        pt = pretrain.Pretrain(vec_filename=test_txt_file.name, save_to_file=False)
        check_embedding(pt.emb)
        assert "unban\xa0mox" in pt.vocab
        # this one also works because of the normalize_text in vocab.py
        assert "unban mox" in pt.vocab
    finally:
        os.unlink(test_txt_file.name)

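# SPACE_PRETRAIN is defined elsewhere in the test module; a hypothetical
# stand-in consistent with the assertions above would be a word2vec-style
# text embedding in which one vocab entry contains an ASCII space (the
# header line of vocab size and dimension is an assumption):
SPACE_PRETRAIN = (
    "2 4\n"
    "unban mox 1.0 2.0 3.0 4.0\n"
    "opal 5.0 6.0 7.0 8.0\n"
)
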
def parse_args():
    parser = argparse.ArgumentParser()
    # the exact options of the --treebanks argument are cut off in the
    # original fragment; nargs='+' is an assumption consistent with the
    # list-valued default set below
    parser.add_argument('--treebanks', type=str, nargs='+',
                        help='Which treebanks to run on')
    parser.add_argument('--pretrain', type=str,
                        default="/home/john/extern_data/wordvec/glove/armenian.pt",
                        help='Which pretrain to use')
    parser.set_defaults(treebanks=[
        "/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Western_Armenian-ArmTDP/hyw_armtdp-ud-train.conllu",
        "/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Armenian-ArmTDP/hy_armtdp-ud-train.conllu"
    ])
    args = parser.parse_args()
    return args


args = parse_args()
pt = pretrain.Pretrain(args.pretrain)
pt.load()
print("Pretrain stats: {} vectors, {} dim".format(len(pt.vocab), pt.emb[0].shape[0]))
for treebank in args.treebanks:
    print(treebank)
    found = 0
    total = 0
    doc = CoNLL.conll2doc(treebank)
    for sentence in doc.sentences:
        for word in sentence.words:
            total = total + 1
            if word.text in pt.vocab:
                found = found + 1
    # report how much of the treebank's vocabulary the pretrain covers
    print("  {} / {} words found".format(found, total))

def test_xz_pretrain():
    """
    Test reading a pretrain directly from an xz-compressed embedding file
    """
    pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz',
                           save_to_file=False)
    check_pretrain(pt)

# Shared fixture building a small Pretrain from the tiny test embedding;
# the decorator is an assumption, since the original fragment shows only
# the bare function.
@pytest.fixture
def pt():
    return pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz',
                             save_to_file=False)

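# Hypothetical consumer of the fixture above: pytest injects the Pretrain
# object by matching the test's argument name to the fixture name.
def test_fixture_vocab(pt):
    assert len(pt.vocab) > 0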