Ejemplo n.º 1
0
def main():
    """Load a Pretrain from command-line paths and report its vocab size.

    argv: <pretrain file> <vector file> <max vocab size>
    """
    filename, vec_filename = sys.argv[1], sys.argv[2]
    max_vocab = int(sys.argv[3])

    pt = pretrain.Pretrain(filename, vec_filename, max_vocab)
    print("Pretrain is of size {}".format(len(pt.vocab)))
Ejemplo n.º 2
0
def load_pretrain(args):
    """Build a Pretrain from the paths in args, or None if disabled.

    Returns None when args['pretrain'] is falsy.  Otherwise locates a
    saved pretrain file; if one already exists on disk, no raw vector
    file is needed, else a vector file is resolved from the args.
    """
    if not args['pretrain']:
        return None

    pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
    if os.path.exists(pretrain_file):
        # an already-saved pretrain makes the raw vectors unnecessary
        vec_file = None
    elif args['wordvec_file']:
        vec_file = args['wordvec_file']
    else:
        vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
    return pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
Ejemplo n.º 3
0
def test_resave_pretrain():
    """
    Test saving a pretrain and then loading from the existing file
    """
    saved = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out',
                                        suffix=".pt",
                                        delete=False)
    try:
        saved.close()
        # First load reads the raw vectors and writes them to saved.name;
        # since the empty temp file is not a valid pretrain, this also
        # exercises the fallback when the existing pretrain isn't working.
        first = pretrain.Pretrain(filename=saved.name,
                                  vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz')
        check_pretrain(first)

        # Second load: the bogus vec_filename shows that the previously
        # saved .pt file is what actually gets loaded.
        second = pretrain.Pretrain(filename=saved.name,
                                   vec_filename=f'unban_mox_opal')
        check_pretrain(second)
    finally:
        os.unlink(saved.name)
Ejemplo n.º 4
0
def load_pretrain(args):
    """
    Loads a pretrain based on the paths in the arguments
    """
    pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'],
                                                args['save_dir'],
                                                args['shorthand'],
                                                args['lang'])
    # an existing saved pretrain takes precedence; only fall back to a
    # raw vector file when no saved file is on disk
    vec_file = None
    if not os.path.exists(pretrain_file):
        vec_file = args['wordvec_file'] or utils.get_wordvec_file(
            args['wordvec_dir'], args['shorthand'])
    return pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
Ejemplo n.º 5
0
def test_whitespace():
    """
    Test reading a pretrain with an ascii space in it

    The vocab word with a space in it should have the correct number
    of dimensions read, with the space converted to nbsp
    """
    vec_txt = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out',
                                          suffix=".txt",
                                          delete=False)
    try:
        vec_txt.write(SPACE_PRETRAIN.encode())
        vec_txt.close()

        pt = pretrain.Pretrain(vec_filename=vec_txt.name,
                               save_to_file=False)
        check_embedding(pt.emb)
        # the ascii space was converted to a non-breaking space on load
        assert "unban\xa0mox" in pt.vocab
        # this one also works because of the normalize_text in vocab.py
        assert "unban mox" in pt.vocab
    finally:
        os.unlink(vec_txt.name)
Ejemplo n.º 6
0
                        help='Which treebanks to run on')
    parser.add_argument(
        '--pretrain',
        type=str,
        default="/home/john/extern_data/wordvec/glove/armenian.pt",
        help='Which pretrain to use')
    parser.set_defaults(treebanks=[
        "/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Western_Armenian-ArmTDP/hyw_armtdp-ud-train.conllu",
        "/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Armenian-ArmTDP/hy_armtdp-ud-train.conllu"
    ])
    args = parser.parse_args()
    return args


# Script body: report pretrain stats, then count how many treebank
# tokens have an exact match in the pretrain vocabulary.
args = parse_args()
pt = pretrain.Pretrain(args.pretrain)
pt.load()
print("Pretrain stats: {} vectors, {} dim".format(len(pt.vocab),
                                                  pt.emb[0].shape[0]))

for treebank in args.treebanks:
    print(treebank)
    # found/total track vocabulary coverage for this one treebank
    found = 0
    total = 0
    doc = CoNLL.conll2doc(treebank)
    for sentence in doc.sentences:
        for word in sentence.words:
            total = total + 1
            # counts only exact surface-form matches against the vocab
            if word.text in pt.vocab:
                found = found + 1
Ejemplo n.º 7
0
def test_xz_pretrain():
    """Vectors read straight from an xz-compressed file should load."""
    vec_path = f'{TEST_WORKING_DIR}/in/tiny_emb.xz'
    pt = pretrain.Pretrain(vec_filename=vec_path, save_to_file=False)
    check_pretrain(pt)
Ejemplo n.º 8
0
def pt():
    """Build a small Pretrain from the tiny xz test embedding (not saved)."""
    vec_path = f'{TEST_WORKING_DIR}/in/tiny_emb.xz'
    return pretrain.Pretrain(vec_filename=vec_path, save_to_file=False)