Ejemplo n.º 1
0
def load_model(doc_path="inference_data", is_savedmodel=False):
    """Build the CUTIE network and return it with a TensorFlow session.

    Args:
        doc_path: Forwarded to ``inference_input`` to build the inference
            parameters.
        is_savedmodel: If true, obtain the session from
            ``load_savedmodel(params.savedmodel_dir)``.  Otherwise create a
            new ``tf.Session``, initialize all global variables and restore
            weights from ``params.e_ckpt_path/params.save_prefix/
            params.ckpt_file``.

    Returns:
        A ``(network, model_output, sess)`` tuple: the CUTIEv1 or CUTIEv2
        network (chosen by ``params.use_cutie2``), the network's
        ``'softmax'`` output, and the session.

    Raises:
        Exception: If restoring the checkpoint fails.  The session created
            for the restore is closed first, and the underlying error is
            chained as ``__cause__``.
    """
    params = inference_input(doc_path)

    data_loader = DataLoader(
        params,
        params.classes,
        update_dict=False,
        load_dictionary=True,
        # No training split; presumably the path holds only test data.
        data_split=0.0)
    # Vocabulary size is floored at 20000 words.
    num_words = max(20000, data_loader.num_words)
    num_classes = data_loader.num_classes

    # model
    network_cls = CUTIEv2 if params.use_cutie2 else CUTIEv1
    network = network_cls(num_words, num_classes, params)
    model_output = network.get_output('softmax')

    if is_savedmodel:
        sess = load_savedmodel(params.savedmodel_dir)
    else:
        # evaluation
        ckpt_saver = tf.train.Saver()
        config = tf.ConfigProto(allow_soft_placement=True)
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())
        # Build the path before the try so the error message below can
        # always reference it (previously a failure here left it unbound).
        ckpt_path = os.path.join(params.e_ckpt_path, params.save_prefix,
                                 params.ckpt_file)
        print('Restoring from {}...'.format(ckpt_path))
        try:
            ckpt_saver.restore(sess, ckpt_path)
        except Exception as err:
            # Don't leak the session we just created; keep the root cause.
            sess.close()
            raise Exception(
                'Check your pretrained {:s}'.format(ckpt_path)) from err
        print('{} restored'.format(ckpt_path))

    return network, model_output, sess
Ejemplo n.º 2
0
# Extra command-line options registered on ``parser`` (presumably an
# argparse.ArgumentParser created above this excerpt), then parsed.
parser.add_argument('--embedding_size', default=128, type=int)
parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--c_threshold', default=0.5, type=float)
params = parser.parse_args()

if __name__ == '__main__':
    # data
    #data_loader = DataLoader(params, True, True) # True to use 25% training data
    data_loader = DataLoader(params, update_dict=False, load_dictionary=True, data_split=0.75) # False to provide a path with only test data
    num_words = max(20000, data_loader.num_words)
    num_classes = data_loader.num_classes

    # model
    if params.use_cutie2:
        network = CUTIEv2(num_words, num_classes, params)
    else:
        network = CUTIEv1(num_words, num_classes, params)
    model_output = network.get_output('softmax')
    
    # evaluation
    ckpt_saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        try:
            #ckpt_path = os.path.join(params.e_ckpt_path, params.save_prefix, params.ckpt_file)
            ckpt_path = '/content/content/CUTIE/graph/INVOICE/CUTIE_atrousSPP_best.ckpt'
            ckpt = tf.train.get_checkpoint_state(ckpt_path)
            print('Restoring from {}...'.format(ckpt_path))
            ckpt_saver.restore(sess, ckpt_path)