def train():
    """Train an arc-standard parser model and save it to ``FLAGS.model``.

    Loads the CoNLL training data named by ``FLAGS.train_data``, then runs
    ``FLAGS.epoch`` passes over the (re-shuffled each pass) sentences,
    parsing each one with an ``MLTrainerActionDecider`` guided by
    ``ArcStandardParsingOracle``.  Once every pass has finished, the
    trainer's parameters are written to ``FLAGS.model``.

    Any exception raised while parsing a sentence is re-raised unchanged,
    after the offending sentence has been logged.
    """
    featExt = extractors.get(FLAGS.feature_extarctor)
    sents = io.transform_conll_sents(FLAGS.train_data, FLAGS.only_projective,
                                     FLAGS.unlex)
    # NOTE(review): 3 is presumably the number of arc-standard transition
    # classes -- confirm against ArcStandardParser2.
    trainer = MLTrainerActionDecider(ml.MultitronParameters(3),
                                     ArcStandardParsingOracle(), featExt)
    parser = ArcStandardParser2(trainer)
    total = len(sents)
    # Fixed seed so the shuffle order (and therefore the model) is reproducible.
    random.seed("seed")
    for epoch in xrange(FLAGS.epoch):
        logging.info("iter %s/%s", epoch + 1, FLAGS.epoch)
        logging.info("  shuffle data ...")
        random.shuffle(sents)
        for i, sent in enumerate(sents):
            if i % 500 == 0:
                logging.info("  step %s/%s ...", i, total)
            try:
                # The parse result is discarded; presumably the trainer
                # decider updates its parameters during parsing.
                parser.parse(sent)
            except Exception:
                logging.info("prob in sent: %s", i)
                logging.info("\n".join(
                    "%s %s %s %s" % (t['id'], t['form'], t['tag'], t['parent'])
                    for t in sent))
                # Bare ``raise`` keeps the original traceback; ``raise e``
                # would discard it on Python 2.
                raise

    # The model file is opened (and truncated) only after training succeeded.
    with open(FLAGS.model, "w") as fout:
        logging.info("save model file to disk [%s] ...", FLAGS.model)
        trainer.save(fout)
# Example #2
# 0
def train():
    """Train an arc-eager parser model, or dump training instances.

    Default mode: train for ``FLAGS.epoch`` passes over the shuffled
    training data (``FLAGS.train_data``) and save the model to
    ``FLAGS.model``.

    External mode (``FLAGS.externaltrainfile`` set): instead of training,
    write feature-vector files for training with an external classifier to
    that path, then exit the process via ``sys.exit()``.  If you don't know
    what this means, just ignore the option; the file format is the same as
    Megam's.

    Any ``IndexError`` raised while parsing a training sentence is re-raised
    unchanged, after the offending sentence has been logged.
    """
    featExt = extractors.get(FLAGS.feature_extarctor)
    sents = io.transform_conll_sents(FLAGS.train_data, FLAGS.only_projective,
                                     FLAGS.unlex)

    if FLAGS.externaltrainfile:
        # ``with`` guarantees the output is flushed and closed before exit.
        with open(FLAGS.externaltrainfile, "w") as fout:
            trainer = LoggingActionDecider(
                ArcEagerParsingOracle(pop_when_can=FLAGS.lazypop), featExt,
                fout)
            parser = ArcEagerParser(trainer)
            for i, sent in enumerate(sents):
                sys.stderr.write(". %s " % i)
                sys.stderr.flush()
                parser.parse(sent)
        sys.exit()

    # NOTE(review): 4 is presumably the number of arc-eager transition
    # classes -- confirm against ArcEagerParser.
    nactions = 4
    trainer = MLTrainerActionDecider(
        ml.MultitronParameters(nactions),
        ArcEagerParsingOracle(pop_when_can=FLAGS.lazypop), featExt)
    parser = ArcEagerParser(trainer)
    import random
    # Fixed seed so the shuffle order (and therefore the model) is reproducible.
    random.seed("seed")
    total = len(sents)
    for epoch in xrange(FLAGS.epoch):
        logging.info("iter %s/%s", epoch + 1, FLAGS.epoch)
        logging.info("  shuffle data ...")
        random.shuffle(sents)
        for i, sent in enumerate(sents):
            if i % 500 == 0:
                logging.info("  step %s/%s ...", i, total)
            try:
                parser.parse(sent)
            except IndexError:
                logging.info("prob in sent: %s", i)
                logging.info("\n".join(
                    "%s %s %s %s" % (t['id'], t['form'], t['tag'], t['parent'])
                    for t in sent))
                # Bare ``raise`` keeps the original traceback; ``raise e``
                # would discard it on Python 2.
                raise

    # Open (and truncate) the model file only after training has succeeded,
    # so a failed run does not clobber an existing model.
    logging.info("save model file to disk [%s] ...", FLAGS.model)
    with open(FLAGS.model, "w") as fout:
        trainer.save(fout)
def test():
    """Evaluate an arc-standard parser model on ``FLAGS.test_data``.

    Writes a trace to ``FLAGS.test_results`` with these lines for each
    sentence:

    * a header line ``"@@@ <index> <running accuracy>"``;
    * if the sentence was parsed and had at least one scored token, a line
      with that sentence's attachment accuracy.

    At the end it prints the overall attachment accuracy and the fraction of
    completely correct sentences.  Sentences whose parse raises
    ``MLTrainerWrongActionException`` are skipped but still count in the
    ``complete`` denominator.  When ``FLAGS.ignore_punc`` is set, tokens
    whose form starts with a punctuation character are not scored.
    """
    featExt = extractors.get(FLAGS.feature_extarctor)
    sents = io.transform_conll_sents(FLAGS.test_data, FLAGS.only_projective,
                                     FLAGS.unlex)
    parser = ArcStandardParser2(
        MLActionDecider(ml.MulticlassModel(FLAGS.model, True), featExt))
    good = 0.0
    bad = 0.0
    complete = 0.0

    with open(FLAGS.test_results, "w") as fout:
        for i, sent in enumerate(sents):
            mistake = False
            sgood = 0.0
            sbad = 0.0
            # The +1 keeps the denominator non-zero before any token has been
            # scored (so this is a slight under-estimate of the running value).
            fout.write("%s %s %s\n" % ("@@@", i, good / (good + bad + 1)))
            try:
                d = parser.parse(sent)
            except MLTrainerWrongActionException:
                continue

            sent = d.annotate(sent)
            for tok in sent:
                if FLAGS.ignore_punc and tok['form'][0] in "`',.-;:!?{}":
                    continue
                if tok['parent'] == tok['pparent']:
                    good += 1
                    sgood += 1
                else:
                    bad += 1
                    sbad += 1
                    mistake = True
            if not mistake:
                complete += 1
            # A sentence consisting only of skipped (punctuation) tokens has
            # nothing to score; dividing would raise ZeroDivisionError.
            if sgood + sbad > 0:
                fout.write("%s\n" % (sgood / (sgood + sbad)))

    scored = good + bad
    n_sents = len(sents)
    # Single-string print: on Python 2 ``print("a", x)`` would print a tuple.
    print("accuracy: %s" % (good / scored if scored else 0.0))
    print("complete: %s" % (complete / n_sents if n_sents else 0.0))
# Example #4
# 0
def test():
    """Evaluate an arc-eager parser model on ``FLAGS.test_data``.

    For each sentence it does the following:

    * writes a progress line ``"@@@ <index> <running accuracy>"`` to stderr;
    * emits the annotated sentence via ``io.out_conll``;
    * appends that sentence's attachment accuracy to ``FLAGS.test_results``,
      if the sentence had at least one scored token.

    At the end it logs several figures:

    * overall attachment accuracy;
    * the fraction of completely correct sentences;
    * recall and precision over ``(sentence, head, dependent)`` arcs;
    * the ratio of predicted to gold arcs ("assigned").

    Tokens the parser left unattached (``pparent == -1``) are counted
    neither as correct nor as wrong, but their gold arcs are still in the
    recall denominator.  When ``FLAGS.ignore_punc`` is set, tokens whose
    form starts with a punctuation character are not scored at all.
    """
    logging.info("test ...")
    featExt = extractors.get(FLAGS.feature_extarctor)
    parser = ArcEagerParser(
        MLActionDecider(ml.MulticlassModel(FLAGS.model), featExt))

    good = 0.0
    bad = 0.0
    complete = 0.0

    # main test loop
    reals = set()
    preds = set()
    with open(FLAGS.test_results, "w") as fout:
        sents = io.transform_conll_sents(FLAGS.test_data,
                                         FLAGS.only_projective, FLAGS.unlex)
        conll_form = 'oform' if FLAGS.unlex else 'form'
        for i, sent in enumerate(sents):
            sgood = 0.0
            sbad = 0.0
            mistake = False
            # +1 keeps the denominator non-zero before any token is scored.
            sys.stderr.write("%s %s %s\n" % ("@@@", i, good /
                                             (good + bad + 1)))
            try:
                d = parser.parse(sent)
            except MLTrainerWrongActionException:
                # this happens only in "early update" parsers, and then we
                # just go on to the next sentence..
                continue
            sent = d.annotate_allow_none(sent)
            for tok in sent:
                if FLAGS.ignore_punc and tok['form'][0] in "`',.-;:!?{}":
                    continue
                reals.add((i, tok['parent'], tok['id']))
                preds.add((i, tok['pparent'], tok['id']))
                if tok['pparent'] == -1:
                    # Unattached token: neither correct nor wrong.
                    continue
                if tok['parent'] == tok['pparent']:
                    good += 1
                    sgood += 1
                else:
                    bad += 1
                    sbad += 1
                    mistake = True
            io.out_conll(sent, parent='pparent', form=conll_form)
            if not mistake:
                complete += 1
            logging.info("test result: sgood[%s], sbad[%s]", sgood, sbad)
            # Write the score whenever something was scored -- including
            # perfect (sbad == 0) and fully wrong (sgood == 0) sentences.
            if sgood + sbad > 0.0:
                fout.write("%s\n" % (sgood / (sgood + sbad)))

        scored = good + bad
        n_sents = len(sents)
        logging.info("accuracy: %s", good / scored if scored else 0.0)
        logging.info("complete: %s", complete / n_sents if n_sents else 0.0)
        # Drop "no head predicted" entries. Distinct comprehension names
        # avoid rebinding ``i``/``parser`` (Python 2 comprehensions leak).
        preds = set(arc for arc in preds if arc[1] != -1)
        n_correct = len(preds.intersection(reals))
        n_reals = float(len(reals))
        n_preds = float(len(preds))
        logging.info("recall: %s", n_correct / n_reals if n_reals else 0.0)
        logging.info("precision: %s",
                     n_correct / n_preds if n_preds else 0.0)
        logging.info("assigned: %s", n_preds / n_reals if n_reals else 0.0)
# Example #5
# 0
gflags.DEFINE_boolean("ignore_punc", True, "Ignore punctuation in evaluation.")
gflags.DEFINE_boolean("only_proj", True,
                      "If true, prune non-projective sentences in training.")
gflags.DEFINE_boolean("add_dep_label", True,
                      "If true, replace the '_' label with 'dep'.")
gflags.DEFINE_integer("random_seed", 0, "Random seed.")

gflags.DEFINE_integer("save_every", 0, "Dump a model every k iterations.")

# Parse flags; ``args`` holds the remaining command-line arguments, with
# args[1] expected to be the CoNLL data file.
args = FLAGS(sys.argv)
print(args)

if len(args) < 2:
    sys.exit("usage: %s [flags] DATA_FILE" % args[0])
DATA_FILE = args[1]

featExt = extractors.get(FLAGS.feature_extractor)

# Materialize the sentences while the file is still open, then close it.
with open(DATA_FILE) as data_fh:
    sents = list(io.conll_to_sents(data_fh))

# Honor the --only_proj flag (previously forced on by ``True or ...``).
if FLAGS.train and FLAGS.only_proj:
    import isprojective
    sents = [s for s in sents if isprojective.is_projective(s)]

if FLAGS.add_dep_label:
    for sent in sents:
        for tok in sent:
            if tok['prel'] == '_':
                tok['prel'] = "dep"

# NOTE(review): EXPLORE and LABELED are consumed by code outside this view;
# their exact semantics are not visible here.
EXPLORE = 1

LABELED = True