Exemplo n.º 1
0
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        initializer = tf.global_variables_initializer()
        sess = tf.Session()
        sess.run(initializer)

        saver.restore(sess, best_path) # restore the model

        system = CacheTransition(cache_size, oracle.utils.OracleType.CL)
        if use_dep:
            shiftpop, pushidx, arcbinary, arclabel = NP2P_data_stream.load_actions(in_path)
            system.shiftpop_action_set, system.push_action_set = shiftpop, pushidx
            system.arcbinary_action_set, system.arclabel_action_set = arcbinary, arclabel
            income_arc_choices, outgo_arc_choices, default_arc_choices = NP2P_data_stream.load_arc_choices(in_path)
            system.income_arcChoices, system.outgo_arcChoices = income_arc_choices, outgo_arc_choices
            system.default_arcChoices = default_arc_choices

        category_res = {feat_vocab.getIndex(x):[0.0,0.0,] for x in ('PHASE=PUSHIDX', 'PHASE=SHTPOP', 'PHASE=ARCBINARY', 'PHASE=ARCLABEL',)}
        devDataStream.reset()
        for i in range(devDataStream.get_num_batch()):
            cur_batch = devDataStream.get_batch(i)
            print('Instance {}'.format(i))
            run_beam_search(sess, system, valid_graph, feat_vocab, action_vocab, cur_batch, cache_size, FLAGS, category_res)
        for k,v in category_res.iteritems():
            k = feat_vocab.getWord(k)
            print('%s : %.4f %d/%d' %(k, v[1]/v[0], v[1], v[0]))


Exemplo n.º 2
0
        saver.restore(sess, best_path)  # restore the model

        system = CacheTransition(cache_size, oracle.utils.OracleType.CL)
        if use_dep:
            shiftpop, pushidx, arcbinary, arclabel = NP2P_data_stream.load_actions(
                in_path)
            system.shiftpop_action_set, system.push_action_set = shiftpop, pushidx
            system.arcbinary_action_set, system.arclabel_action_set = arcbinary, arclabel
            income_arc_choices, outgo_arc_choices, default_arc_choices = NP2P_data_stream.load_arc_choices(
                in_path)
            system.income_arcChoices, system.outgo_arcChoices = income_arc_choices, outgo_arc_choices
            system.default_arcChoices = default_arc_choices

        category_res = {
            feat_vocab.getIndex(x): [
                0.0,
                0.0,
            ]
            for x in (
                'PHASE=PUSHIDX',
                'PHASE=SHTPOP',
                'PHASE=ARCBINARY',
                'PHASE=ARCLABEL',
            )
        }
        devDataStream.reset()
        for i in range(devDataStream.get_num_batch()):
            cur_batch = devDataStream.get_batch(i)
            print('Instance {}'.format(i))
            run_beam_search(sess, system, valid_graph, feat_vocab,