Exemple #1
0
def demotest(sentence):
    """Tag one sentence with the latest BiLSTM-CRF checkpoint and return its entities.

    Restores the newest checkpoint under the module-level ``model_path`` (and
    records it in ``paths['model_path']``), tags ``sentence`` character by
    character and extracts the PER / LOC / ORG entity lists.

    Args:
        sentence: Raw input text (str).

    Returns:
        A ``(PER, LOC, ORG)`` tuple as produced by ``get_entity``. For an empty
        or whitespace-only sentence each element is ``['']``.
    """
    # Reject blank input before paying for graph construction and restore.
    if sentence == '' or sentence.isspace():
        print('语句为空')
        return ([''], [''], [''])

    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    # Build in a private graph: re-running build_graph() in the process-wide
    # default graph on a later call would add a second copy of the model's
    # variables, which can make the build or the restore fail.
    with tf.Graph().as_default():
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            print('============= demo =============')
            saver.restore(sess, ckpt_file)
            demo_sent = list(sentence.strip())
            # demo_one expects (chars, tags) pairs; the 'O' tags are placeholders.
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            tag = model.demo_one(sess, demo_data)
    PER, LOC, ORG = get_entity(tag, demo_sent)
    print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
    return (PER, LOC, ORG)
Exemple #2
0
def evaluate_words(lines):
    """Tag *lines* with the latest checkpoint, then print and return its entities.

    Restores the newest checkpoint under the module-level ``model_path`` and
    tags the stripped text character by character. Along the way it prints
    the raw text, the character list and the PER / LOC / ORG results.

    Args:
        lines: Text to tag (str).

    Returns:
        ``(PER, LOC, ORG)`` as produced by ``get_entity``.
    """
    print("start evaluate_words")
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    # Build in a private graph: re-running build_graph() in the process-wide
    # default graph on a later call would add a second copy of the model's
    # variables, which can make the build or the restore fail.
    with tf.Graph().as_default():
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            print('============= demo =============')
            saver.restore(sess, ckpt_file)

            print(lines)
            demo_sent = list(lines.strip())
            print(demo_sent)
            # demo_one expects (chars, tags) pairs; the 'O' tags are placeholders.
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            tag = model.demo_one(sess, demo_data)
    PER, LOC, ORG = get_entity(tag, demo_sent)
    print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
    return PER, LOC, ORG
Exemple #3
0
 def get_org(self, text):
     """Return the set of ORG entities ``self.model`` finds in *text* (via ``self.sess``)."""
     chars = list(text.strip())
     # demo_one takes (chars, tags) pairs; the tags are dummy 'O' placeholders.
     batch = [(chars, ['O'] * len(chars))]
     tags = self.model.demo_one(self.sess, batch)
     _, _, organisations = get_entity(tags, chars)
     return set(organisations)
def save_result(saver, sess, decoder, ckpt_file,
                input_path="./data/case/input.txt",
                output_path="./data/case/result.json"):
    """Tag every line of a text file and dump the entities as JSON.

    Restores ``ckpt_file`` into ``sess`` with ``saver`` and runs
    ``decoder.demo_one`` on each line of ``input_path``. It then writes a JSON
    list of ``{'sen', 'PER', 'LOC', 'ORG'}`` records to ``output_path``.

    Args:
        saver: Saver used to restore the checkpoint.
        sess: Session the model lives in.
        decoder: Object exposing ``demo_one(sess, data)``.
        ckpt_file: Checkpoint path to restore.
        input_path: UTF-8 text file, one sentence per line.
        output_path: Destination JSON file (UTF-8, non-ASCII kept as-is).
    """
    saver.restore(sess, ckpt_file)
    # Read with an explicit encoding: the result is written as UTF-8, and the
    # platform default encoding may not decode the (Chinese) input.
    with open(input_path, "r", encoding="utf-8") as f:
        sentences = f.readlines()
    total = len(sentences)
    result = []
    for step, sen in enumerate(sentences, start=1):
        sys.stdout.write(' processing: {} sentence / {} sentences.'.format(step, total) + '\r')
        # strip() already removes the trailing '\r\n'; tag char by char.
        chars = list(sen.strip())
        tag = decoder.demo_one(sess, [(chars, ['O'] * len(chars))])
        PER, LOC, ORG = get_entity(tag, chars)
        # 'sen' stores the line exactly as read (newline included when present).
        result.append({'sen': sen, 'PER': PER, 'LOC': LOC, 'ORG': ORG})
    with open(output_path, 'w', encoding='utf-8') as fw:
        fw.write(json.dumps(result, ensure_ascii=False))
    print("********The result is saved in the {}*********".format(output_path) + '\r')
Exemple #5
0
def getRest(input):
    """Tag one sentence with the latest checkpoint and return a status dict.

    Args:
        input: Sentence to tag (str). The name shadows the builtin but is kept
            for callers that pass it by keyword.

    Returns:
        ``{'status': 'fail'}`` for empty or whitespace-only input. Otherwise
        ``{'status': 'success', 'PER': ..., 'LOC': ..., 'ORG': ...}`` holding
        the lists returned by ``get_entity``.
    """
    demo_sent = input
    # Validate before paying for graph construction and checkpoint restore.
    if demo_sent == '' or demo_sent.isspace():
        return {'status': 'fail'}

    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    # Build in a private graph: re-running build_graph() in the process-wide
    # default graph on a later call would add a second copy of the model's
    # variables, which can make the build or the restore fail.
    with tf.Graph().as_default():
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            print('============= demo =============')
            saver.restore(sess, ckpt_file)
            chars = list(demo_sent.strip())
            # demo_one expects (chars, tags) pairs; the 'O' tags are placeholders.
            tag = model.demo_one(sess, [(chars, ['O'] * len(chars))])
    PER, LOC, ORG = get_entity(tag, chars)
    return {'status': 'success', 'PER': PER, 'LOC': LOC, 'ORG': ORG}
Exemple #6
0
    def handle_POST(self, request, context, scheme, value):
        """Record a user-submitted edit to an OpenStreetMap-backed entity.

        Only entities whose source module is ``molly.providers.apps.maps.osm``
        may be edited; anything else raises ``Http404``. A valid
        ``UpdateOSMForm`` produces an ``OSMUpdate`` holding the old and new
        metadata as JSON, with the OSM version bumped. The entity itself is
        not modified here. The user is then redirected with
        ``?submitted=true``. An invalid form is re-rendered.
        """
        entity = context['entity'] = get_entity(scheme, value)
        if entity.source.module_name != 'molly.providers.apps.maps.osm':
            raise Http404

        form = UpdateOSMForm(request.POST)
        if form.is_valid():
            # Work on a copy of the 'osm' sub-dict so entity.metadata stays intact.
            new_metadata = copy.deepcopy(entity.metadata['osm'])
            for k in ('name', 'operator', 'phone', 'opening_hours', 'url', 'cuisine', 'food', 'food__hours', 'atm', 'collection_times', 'ref', 'capacity'):
                # Form field names use '__' where OSM tag keys use ':'.
                tag_name = k.replace('__', ':')
                if form.cleaned_data[k]:
                    new_metadata['tags'][tag_name] = form.cleaned_data[k]
                elif tag_name in new_metadata.get('tags', {}):
                    # Field was cleared: drop the existing tag. Tags live
                    # directly under new_metadata['tags'] (new_metadata is
                    # already the 'osm' sub-dict).
                    del new_metadata['tags'][tag_name]

            new_metadata['attrs']['version'] = str(int(new_metadata['attrs']['version'])+1)

            osm_update = OSMUpdate(
                contributor_name = form.cleaned_data['contributor_name'],
                contributor_email = form.cleaned_data['contributor_email'],
                contributor_attribute = form.cleaned_data['contributor_attribute'],
                entity = entity,
                old = simplejson.dumps(entity.metadata),
                new = simplejson.dumps(new_metadata),
                notes = form.cleaned_data['notes'],
            )
            osm_update.save()

            # NOTE(review): confirm this URL name; a hyphenated
            # 'places:entity-update' variant also appears in this codebase.
            return HttpResponseRedirect(reverse('places:entity_update', args=[scheme, value])+'?submitted=true')
        else:
            context['form'] = form
            return self.render(request, context, 'places/update_osm')
Exemple #7
0
 def get_metadata(self, request, scheme, value):
     """Return the title and an HTML summary line for the entity page.

     The summary is the entity's capitalised primary type in bold. When the
     distance from the session's geolocation is truthy (not None/zero), the
     approximate distance in km and the bearing are appended.
     """
     entity = get_entity(scheme, value)
     here = request.session.get("geolocation:location")
     distance, bearing = entity.get_distance_and_bearing_from(here)
     parts = ["<strong>%s</strong>" % capfirst(entity.primary_type.verbose_name)]
     if distance:
         parts.append(", approximately %.3fkm %s" % (distance / 1000, bearing))
     return {"title": entity.title, "additional": "".join(parts)}
Exemple #8
0
 def initial_context(self, request, scheme, value):
     """Extend the base view context with the entity and all of its types."""
     context = super(EntityDetailView, self).initial_context(request)
     entity = get_entity(scheme, value)
     types = entity.all_types.all()
     context['entity'] = entity
     context['entity_types'] = types
     return context
Exemple #9
0
def deploy_model(d):
    """Tag ``d['1']`` with the exported SavedModel and format its entities.

    Loads the SavedModel under ``./model/1506177919`` (tag
    ``'test_saved_model'``, signature ``'test_signature'``). It runs the
    sentence through it, Viterbi-decodes the logits with the exported
    transition matrix and maps label ids back to tags.

    Args:
        d: Mapping whose ``'1'`` entry holds the sentence to tag.

    Returns:
        A ``'PER: ...\\nLOC: ...\\nORG: ...\\nOTH: ...'`` string, or ``None`` if
        ``batch_yield`` yields no batch.
    """
    vocab = read_dictionary(os.path.join(os.getcwd(),
                                         'data_path/word2id1.pkl'))

    signature_key = 'test_signature'
    input_key = 'input_x'
    input_key2 = 'sequence_length'
    output_key = 'output'
    output_key2 = 'transition_param'

    # Inverse of tag2label; label 0 maps to the integer 0 itself, not a tag name.
    label2tag = {label: (tag if label != 0 else label)
                 for tag, label in tag2label.items()}

    # Load into a private graph so repeated calls do not re-import the
    # SavedModel into the process-wide default graph, where the clashing node
    # names would make the signature's tensor names resolve to a stale copy.
    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph_def = tf.saved_model.loader.load(
            sess, ['test_saved_model'],
            os.path.join(os.getcwd(), 'model/1506177919'))
        # Look up the concrete input/output tensor names in the SignatureDef.
        signature = meta_graph_def.signature_def[signature_key]
        word_ids = sess.graph.get_tensor_by_name(signature.inputs[input_key].name)
        sequence_lengths = sess.graph.get_tensor_by_name(signature.inputs[input_key2].name)
        y = sess.graph.get_tensor_by_name(signature.outputs[output_key].name)
        y2 = sess.graph.get_tensor_by_name(signature.outputs[output_key2].name)

        sent = d['1']
        input_sent = list(sent.strip())
        input_data = [(input_sent, ['O'] * len(input_sent))]
        for seqs, labels in batch_yield(input_data, vocab, tag2label):
            _, seq_len_list = get_feed_dict(seqs, labels)

            logits, transition_params = sess.run(
                [y, y2],
                feed_dict={word_ids: seqs, sequence_lengths: seq_len_list})

            # Decode each sequence up to its own length.
            label_list = [
                viterbi_decode(logit[:seq_len], transition_params)[0]
                for logit, seq_len in zip(logits, seq_len_list)
            ]

            tag = [label2tag[label] for label in label_list[0]]
            PER, LOC, ORG, OTH = get_entity(tag, input_sent)
            return ('PER: {}\nLOC: {}\nORG: {}\nOTH: {}'.format(
                PER, LOC, ORG, OTH))
Exemple #10
0
 def predict(self, demo_sent):
     """Return the entities found in *demo_sent*, or ``{}`` for blank input.

     Blank (empty or whitespace-only) input prints a farewell message and
     yields an empty dict. Otherwise the result of ``get_entity`` is returned
     unchanged.
     """
     if demo_sent != '' and not demo_sent.isspace():
         chars = list(demo_sent.strip())
         tags = self.model.demo_one(self.sess, [(chars, ['O'] * len(chars))])
         return get_entity(tags, chars)
     print('See you next time!')
     return {}
Exemple #11
0
# Ordered (old, new) replacements applied by _clean_sentence. Order matters:
# '〔'/'〕' become spaces only after spaces were turned into commas, and '...'
# is removed after the bracket/quote characters but before '⒄'.
# NOTE(review): some ASCII entries may originally have been full-width
# characters (e.g. '，', '　'); confirm the intended set.
_SENTENCE_REPLACEMENTS = (
    (u" ", u","),
    (u"《", u""),
    (u"》", u""),
    (u"[", u""),
    (u"]", u""),
    (u"(", u""),
    (u")", u""),
    (u"—", u""),
    (u"〔", u" "),
    (u"〕", u" "),
    (u'"', u""),
    (u"“", u""),
    (u"”", u""),
    (u"...", u""),
    (u"⒄", u""),
)


def _clean_sentence(sentence):
    """Apply _SENTENCE_REPLACEMENTS to *sentence* in order and return the result."""
    for old, new in _SENTENCE_REPLACEMENTS:
        sentence = sentence.replace(old, new)
    return sentence


def _collect_entities(label, found, bucket):
    """Append each multi-character entity in *found* not yet in *bucket*, printing it."""
    for ent in found:
        if len(ent) > 1 and ent not in bucket:
            print('{}: {}'.format(label, ent))
            bucket.append(ent)


def do_lstm(sentence):
    """Extract de-duplicated PER/LOC/ORG entities from *sentence*.

    Uses the module-level ``sess`` and ``model``. Nothing is tagged if
    ``sess`` is ``None`` or the sentence has 10 characters or fewer. The
    sentence is normalised with ``_clean_sentence`` and skipped when fewer
    than 10 characters remain. Single-character entities are ignored, and
    each kept entity is printed once.

    Returns:
        ``(per, loc, org)`` lists in first-seen order.
    """
    per = []
    loc = []
    org = []

    if sess is None:
        return per, loc, org

    print('============= demo =============')
    if len(sentence) <= 10:
        return per, loc, org

    cleaned = _clean_sentence(sentence)
    if len(cleaned) < 10:
        return per, loc, org

    chars = list(cleaned.strip())
    tag = model.demo_one(sess, [(chars, ['O'] * len(chars))])
    try:
        PER, LOC, ORG = get_entity(tag, chars)
        _collect_entities('PER', PER, per)
        _collect_entities('LOC', LOC, loc)
        _collect_entities('ORG', ORG, org)
    except Exception as e:
        # Best-effort: report unexpected get_entity output and keep whatever
        # was collected so far.
        print(e)

    return per, loc, org
Exemple #12
0
 def get_metadata(self, request, scheme, value):
     """Build the page title and the HTML 'additional' line for an entity.

     The line shows the capitalised primary type in bold. When the distance
     from the session geolocation is truthy, the approximate km distance and
     the bearing follow it.
     """
     entity = get_entity(scheme, value)
     distance, bearing = entity.get_distance_and_bearing_from(
         request.session.get('geolocation:location'))
     additional = '<strong>%s</strong>' % capfirst(entity.primary_type.verbose_name)
     additional += (', approximately %.3fkm %s' % (distance/1000, bearing)) if distance else ''
     return dict(title=entity.title, additional=additional)
Exemple #13
0
 def breadcrumb(self, request, context, scheme, value):
     """Breadcrumb whose parent is the nearby list when the session has a
     geolocation, and the category list otherwise."""
     has_location = request.session.get("geolocation:location")
     parent_view = "nearby-detail" if has_location else "category-detail"
     entity = get_entity(scheme, value)
     parent = lazy_parent(parent_view, ptypes=entity.primary_type.slug)
     title = context["entity"].title
     url = lazy_reverse("entity", args=[scheme, value])
     return Breadcrumb("places", parent, title, url)
Exemple #14
0
 def breadcrumb(self, request, context, scheme, value):
     """Return the 'places' breadcrumb for an entity page.

     The parent is 'nearby-detail' when a geolocation is stored in the
     session, otherwise 'category-detail', keyed by the entity's primary type.
     """
     parent_view = ('nearby-detail'
                    if request.session.get('geolocation:location')
                    else 'category-detail')
     primary_slug = get_entity(scheme, value).primary_type.slug
     parent = lazy_parent(parent_view, ptypes=primary_slug)
     return Breadcrumb(
         'places',
         parent,
         context['entity'].title,
         lazy_reverse('entity', args=[scheme, value]),
     )
Exemple #15
0
def page(request, entity_type, entity_id, layout_slug=None):
	"""Render an entity page using the layout chosen for it.

	Looks up the entity by type and integer id; any failure is printed as a
	traceback and turned into ``Http404``. It then picks the layout for
	``layout_slug``, falling back to ``Layout.get()`` on error, and renders
	that layout's template. All local variables (including ``entity`` and
	``__layout__``) are passed as the template context.
	"""
	try:
		entity = get_entity(entity_type, pk=int(entity_id))
	except Exception:
		# Catch ordinary errors only (bad id, missing entity, ...); a bare
		# except would also swallow KeyboardInterrupt and SystemExit.
		import traceback
		traceback.print_exc()
		raise Http404
	try:
		__layout__ = get_layout(slug=layout_slug, entity_type=entity_type, entity=entity)
	except Exception:
		__layout__ = Layout.get()
	output = locals()
	context = RequestContext(request)
	return render_to_response(__layout__.get_template(), output, context)
Exemple #16
0
    def handle_POST(self, request, context, scheme, value):
        """Record a user-submitted edit to an OpenStreetMap-backed entity.

        Only entities whose source module is ``molly.providers.apps.maps.osm``
        may be edited; anything else raises ``Http404``. A valid
        ``UpdateOSMForm`` produces an ``OSMUpdate`` holding the old and new
        metadata as JSON, with the OSM version bumped. The entity itself is
        not modified here. The user is then redirected with
        ``?submitted=true``. An invalid form is re-rendered.
        """
        entity = context["entity"] = get_entity(scheme, value)
        if entity.source.module_name != "molly.providers.apps.maps.osm":
            raise Http404

        form = UpdateOSMForm(request.POST)
        if form.is_valid():
            # Work on a copy of the "osm" sub-dict so entity.metadata stays intact.
            new_metadata = copy.deepcopy(entity.metadata["osm"])
            for k in (
                "name",
                "operator",
                "phone",
                "opening_hours",
                "url",
                "cuisine",
                "food",
                "food__hours",
                "atm",
                "collection_times",
                "ref",
                "capacity",
            ):
                # Form field names use "__" where OSM tag keys use ":".
                tag_name = k.replace("__", ":")
                if form.cleaned_data[k]:
                    new_metadata["tags"][tag_name] = form.cleaned_data[k]
                elif tag_name in new_metadata.get("tags", {}):
                    # Field was cleared: drop the existing tag. Tags live
                    # directly under new_metadata["tags"] (new_metadata is
                    # already the "osm" sub-dict).
                    del new_metadata["tags"][tag_name]

            new_metadata["attrs"]["version"] = str(int(new_metadata["attrs"]["version"]) + 1)

            osm_update = OSMUpdate(
                contributor_name=form.cleaned_data["contributor_name"],
                contributor_email=form.cleaned_data["contributor_email"],
                contributor_attribute=form.cleaned_data["contributor_attribute"],
                entity=entity,
                old=simplejson.dumps(entity.metadata),
                new=simplejson.dumps(new_metadata),
                notes=form.cleaned_data["notes"],
            )
            osm_update.save()

            return HttpResponseRedirect(reverse("places:entity-update", args=[scheme, value]) + "?submitted=true")
        else:
            context["form"] = form
            return self.render(request, context, "places/update_osm")
Exemple #17
0
def demo(model, config, word2id):
    """Interactive console loop: read sentences from stdin and print their entities.

    Restores the latest checkpoint in ``config.model_dir`` and tags each
    entered sentence with ``demo_one``. An empty or blank line, or end of
    input, ends the loop. The session is closed when the loop exits.
    """
    ckpt_file = tf.train.latest_checkpoint(config.model_dir)
    # Context manager so the session is released when the loop ends or an
    # exception escapes.
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=ckpt_file)  # load the saved model

        while True:
            print('Please input your sentence:')
            try:
                demo_sent = input()
            except EOFError:
                # stdin closed: treat it like a blank line and stop.
                demo_sent = ''
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            demo_sent = list(demo_sent.strip())
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            tag = demo_one(session, model, config, word2id, demo_data)
            PER, LOC, ORG = get_entity(tag, demo_sent)
            print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Exemple #18
0
def test101(**kwargs):
    """Tag ``kwargs['demo_sent']`` through ``BiLSTM_CRF_Client`` and print the entities.

    Expected keyword arguments: ``demo_sent`` (str) and ``server``, which is
    passed straight to ``BiLSTM_CRF_Client.demo_one``.
    """
    import argparse
    from utils import str2bool
    from data import read_dictionary, tag2label

    print('test101', kwargs)

    # Only the defaults are used: the parser is given an empty argv.
    parser = argparse.ArgumentParser(
        description='BiLSTM-CRF for Chinese NER task')
    for flag, kind, default, help_text in (
            ('--train_data', str, 'data_path', 'train data source'),
            ('--demo_model', str, '1521112368', 'model for test and demo'),
            ('--batch_size', int, 64, '#sample of each minibatch'),
    ):
        parser.add_argument(flag, type=kind, default=default, help=help_text)
    args = parser.parse_args([])

    vocab_path = os.path.join('.', args.train_data, 'word2id.pkl')
    word2id = read_dictionary(vocab_path)
    client = BiLSTM_CRF_Client(args, tag2label, word2id)

    chars = list(kwargs.get("demo_sent").strip())
    print('demo_sent', len(chars))
    batch = [(chars, ['O'] * len(chars))]

    ret1 = client.demo_one(kwargs.get("server"), batch, verbose=False)
    print('result-1', ret1)

    from utils import get_entity

    PER, LOC, ORG = get_entity(ret1, chars)
    print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Exemple #19
0
def ner_predict(cases, encrypt=False):
    """Extract the distinct entities of each case text with the loaded model.

    Args:
        cases: Iterable of case strings. When ``encrypt`` is true, each one is
            decrypted with the configured ``encryptKey`` and then stripped of
            every '|', newline and space.
        encrypt: Whether the inputs need decrypting first.

    Returns:
        One list per case holding the unique values returned by
        ``get_entity`` (treated as locations here), in first-seen order.
    """
    # One cipher for the whole batch, built only when it is actually needed.
    pc = prpcrypt(get_config().get("encryptKey", "key")) if encrypt else None
    result = []
    for case_str in cases:
        if encrypt:
            case_str = pc.decrypt(bytes(case_str, encoding='utf-8'))
            # Literal removals: together these drop every '|', '\n' and ' ',
            # the same net effect as the earlier regex substitutions.
            for junk in ('|', '\n', ' '):
                case_str = case_str.replace(junk, '')
        predict_sent = list(case_str.strip())
        demo_data = [(predict_sent, ['O'] * len(predict_sent))]
        tag = model.demo_one(sess, demo_data)
        LOC = get_entity(tag, predict_sent)
        # dict.fromkeys de-duplicates like set() but keeps a deterministic order.
        result.append(list(dict.fromkeys(LOC)))
    return result
Exemple #20
0
def entity_favourite(request, type_slug, id):
    """Mark or unmark an entity as a favourite for the current request.

    Only POST is allowed (405 otherwise). The ``is_favourite`` field is
    required (400 if missing) and counts as true only for the string
    ``'true'``. With ``no_redirect`` an empty 200 response is returned;
    otherwise the user is redirected to ``return_url`` or the entity page.
    """
    entity = get_entity(type_slug, id)

    if request.method != 'POST':
        return HttpResponse('', mimetype='text/plain', status=405)

    try:
        value = request.POST['is_favourite'] == 'true'
    except KeyError:
        return HttpResponse('', mimetype='text/plain', status=400)

    make_favourite(request, entity, value)

    if 'no_redirect' in request.POST:
        # The update succeeded, so report success rather than 400 Bad Request.
        return HttpResponse('', mimetype='text/plain', status=200)

    if 'return_url' in request.POST:
        # SECURITY(review): return_url comes straight from the client and is
        # not validated, which permits open redirects; consider restricting
        # it to same-site paths.
        return HttpResponseRedirect(request.POST['return_url'])
    else:
        return HttpResponseRedirect(entity.get_absolute_url())
Exemple #21
0
def entity_favourite(request, type_slug, id):
    """Mark or unmark an entity as a favourite for the current request.

    Only POST is allowed (405 otherwise). The ``is_favourite`` field is
    required (400 if missing) and counts as true only for the string
    ``"true"``. With ``no_redirect`` an empty 200 response is returned;
    otherwise the user is redirected to ``return_url`` or the entity page.
    """
    entity = get_entity(type_slug, id)

    if request.method != "POST":
        return HttpResponse("", mimetype="text/plain", status=405)

    try:
        value = request.POST["is_favourite"] == "true"
    except KeyError:
        return HttpResponse("", mimetype="text/plain", status=400)

    make_favourite(request, entity, value)

    if "no_redirect" in request.POST:
        # The update succeeded, so report success rather than 400 Bad Request.
        return HttpResponse("", mimetype="text/plain", status=200)

    if "return_url" in request.POST:
        # SECURITY(review): return_url comes straight from the client and is
        # not validated, which permits open redirects; consider restricting
        # it to same-site paths.
        return HttpResponseRedirect(request.POST["return_url"])
    else:
        return HttpResponseRedirect(entity.get_absolute_url())
def main(test_sent):
    start_time = time.time()
    channel = implementations.insecure_channel('192.168.1.210', 5075)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    test_sent = list(test_sent.strip())
    test_data = [(test_sent, ['O'] * len(test_sent))]
    label_list = []
    for seqs, labels in batch_yield(test_data,
                                    batch_size=64,
                                    vocab=word2id,
                                    tag2label=tag2label,
                                    shuffle=False):
        label_list_, _ = predict_one_batch(seqs, stub)
        label_list.extend(label_list_)
    # label2tag = {}
    # for tag, label in tag2label.items():
    #     label2tag[label] = tag if label != 0 else label
    tag = [label2tag[label] for label in label_list[0]]
    print 'tag', tag
    PER, LOC, ORG = get_entity(tag, test_sent)
    time_used = time.time() - start_time
    print 'tim_used', time_used
    return PER, LOC, ORG
Exemple #23
0
 def handle_GET(self, request, context, scheme, value, ptype):
     """Resolve the entity from (scheme, value) and delegate to the parent view's GET."""
     target = get_entity(scheme, value)
     parent = super(NearbyEntityDetailView, self)
     return parent.handle_GET(request, context, ptype, target)
Exemple #24
0
 def get_metadata(self, request, scheme, value, ptype):
     """Return the parent view's metadata for the entity addressed by (scheme, value)."""
     # Resolve the entity from this view's own URL arguments; ``type_slug``
     # and ``id`` are not parameters of this method.
     entity = get_entity(scheme, value)
     return super(NearbyEntityDetailView, self).get_metadata(request, ptype, entity)
Exemple #25
0
 def initial_context(self, request, scheme, value, ptype):
     """Parent view's context for *ptype* around the entity, plus the entity itself."""
     target = get_entity(scheme, value)
     parent = super(NearbyEntityDetailView, self)
     ctx = parent.initial_context(request, ptype, target)
     ctx["entity"] = target
     return ctx
Exemple #26
0
    print("test data: {}".format(test_size))
    model.test(test_data)

## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while (1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                ENT, HYPER = get_entity(tag, demo_sent)
                print('ENT: {}\nHYPER: {}'.format(ENT, HYPER))
Exemple #27
0
    model.test(test_data)

## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = model.BiLSTM_CRF(args,
                             embeddings,
                             data.tag2label,
                             word2id,
                             paths,
                             config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while (1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG = utils.get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Exemple #28
0
                charge = one_text['charge']  # 犯罪原因
                judgementId = one_text['judgementId']  # 判决Id,唯一标示
                keywords = one_text['keywords']  # 关键词
                court = one_text['court']  # 法院信息
                judge_text = one_text['judge_text']  # 判决结果,是一个列表,继续循环
                proponents = one_text['proponents']  # 原告
                opponents = one_text['opponents']  # 被告

                for text in judge_text:  # 处理判决结果
                    text = re.sub("<a.+?</a>", '', text)
                    if text is '':
                        continue
                    print('judge_text: ', text)
                    demo_data = [(text, ['O'] * len(text))]
                    tag = model.demo_one(sess, demo_data)
                    PER, LOC, ORG = get_entity(tag, text)
                    MON = get_MON_entity(text)
                    print('PER: {}\nLOC: {}\nORG: {}\nMON: {}\n'.format(
                        PER, LOC, ORG, MON))

                    # 将数据写入es
                    es.index(index='zhizhuxia_sichuan',
                             doc_type='ner_type',
                             body={
                                 'addr': addr,
                                 'charge': charge,
                                 'judgementId': judgementId,
                                 'keywords': keywords,
                                 'court': court,
                                 'judge_text': text,
                                 'PER': PER,
Exemple #29
0
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while (1):
            print('输入待识别句子:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                # 邱实在实验室买戴尔的显示屏
                # ['邱', '实', '在', '实', '验', '室', '买', '戴', '尔', '的', '显', '示', '屏']
                demo_sent = list(demo_sent.strip())
                # [(
                # ['邱', '实', '在', '实', '验', '室', '买', '戴', '尔', '的', '显', '示', '屏'],
                # ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
                # )]
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                # ['B-PER', 'I-PER', 0, 0, 0, 0, 0, 'B-PER', 'I-PER', 0, 0, 0, 0]
                tag = model.demo_one(sess, demo_data)
                # PER, LOC, ORG = get_entity(tag, demo_sent)
                targets = get_entity(tag, demo_sent)
                print('属性词: {}'.format(targets))

# 2019.03.07 15:45 start train
Exemple #30
0
 def initial_context(self, request, scheme, value):
     """Return a copy of the base context with ``entity`` set to the addressed entity."""
     base_context = super(EntityUpdateView, self).initial_context(request)
     entity = get_entity(scheme, value)
     context = dict(base_context)
     context['entity'] = entity
     return context
Exemple #31
0
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while (1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                # name 参照 data.py tag2label
                PER = get_entity(tag, demo_sent, 'PER')
                SEX = get_entity(tag, demo_sent, 'SEX')
                TIT = get_entity(tag, demo_sent, 'TIT')
                REA = get_entity(tag, demo_sent, 'PER')
                print('PER: {}\nSEX: {}\nTIT: {}\nREA: {}'.format(
                    PER, SEX, TIT, REA))

## predict
elif args.mode == "predict":
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
 ckpt_file = tf.train.latest_checkpoint(model_path)
 model = BLC(batch_size=args['batch_size'],
             epoch_num=args['epoch'],
             hidden_dim=args['hidden_dim'],
             embeddings=embeddings,
             dropout_keep=args['dropout'],
             optimizer=args['optimizer'],
             lr=args['lr'],
             clip_grad=args['clip'],
             tag2label=tag2label,
             vocab=word2id,
             shuffle=args['shuffle'],
             model_path=ckpt_file,
             summary_path=summary_path,
             result_path=result_path,
             update_embedding=args['update_embedding'])
 model.build_graph()
 saver = tf.train.Saver()
 with tf.Session() as sess:
     saver.restore(sess, ckpt_file)
     while (1):
         print('输入待识别的句子: ')
         sent = input()
         if sent == '' or sent.isspace():
             break
         else:
             sent = list(sent.strip())
             data = [(sent, ['O'] * len(sent))]
             tag = model.test(sess, data)
             ENT = get_entity(tag, sent)
             print('ENT: {}\n'.format(ENT))
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while(1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                ENT, EVA, ALL = get_entity(tag, demo_sent)
                print('ENT: {}\nEVA: {}\nALL: {}\n'.format(ENT, EVA, ALL))
elif args.mode == 'all':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print('ckpt_file:',ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        result = open('result.txt', 'w',encoding='utf8')
        with open('content.txt', encoding='utf8') as f:
            count = 0
            error_count = 0
Exemple #34
0
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_size))
    model.test(test_data)

## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while(1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG, UNI, JOB = get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}\nUNI: {}\nJOB: {}'.format(PER, LOC, ORG, UNI, JOB))
Exemple #35
0
 i = 0
 with open('data_path/es_data/es_weibo_data',
           'r',
           encoding='gb18030') as f:
     for line in f:
         i += 1
         print(i)
         demo_sent = line.strip()
         demo_sent = list(demo_sent)
         #print (demo_sent)
         demo_data = [(demo_sent, ['Neg_B'] * len(demo_sent))]
         #print (demo_data)
         #print (len(demo_sent))
         tag = model.demo_one(sess, demo_data)
         #print (tag)
         Neg, Pos, Neu = get_entity(tag, demo_sent)
         flag = 0
         posList = []
         negList = []
         for n in Neg:
             if Neg_dic.get(n) is None and len(n) > 1:
                 #fo.write(line.strip()+'\t'+'Neg:')
                 flag = 1
                 Neg_dic[n] = -1
                 negList.append(n)
                 #fo.write(n +' ')
         for p in Pos:
             if Pos_dic.get(p) is None and len(p) > 1:
                 flag = 1
                 Pos_dic[p] = 1
                 posList.append(p)
Exemple #36
0
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_size))
    model.test(test_data)

## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while(1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG = get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Exemple #37
0
             embeddings=embeddings,
             dropout_keep=args['dropout'],
             optimizer=args['optimizer'],
             lr=args['lr'],
             clip_grad=args['clip'],
             tag2label=tag2label,
             vocab=word2id,
             shuffle=args['shuffle'],
             model_path=ckpt_file,
             summary_path=summary_path,
             result_path=result_path,
             update_embedding=args['update_embedding'])
 model.build_graph()
 saver = tf.train.Saver()
 with tf.Session() as sess:
     saver.restore(sess, ckpt_file)
     while (1):
         print('输入待识别的句子: ')
         sent = input()
         if sent == '' or sent.isspace():
             break
         else:
             sent = list(sent.strip())
             data = [(sent, ['O'] * len(sent))]
             tag = model.test(sess, data)
             #ENT = get_entity(tag, sent)
             PER, LOC, ORG = get_entity(tag, sent)
             #print('ENT: {}\n'.format(ENT))
             print('PER: {}\n'.format(PER))
             print('LOC: {}\n'.format(LOC))
             print('ORG: {}\n'.format(ORG))
Exemple #38
0
 def get_metadata(self, request, scheme, value):
     """Resolve the entity named by ``scheme``/``value`` and delegate.

     The parent class's ``get_metadata`` is called with the resolved entity
     in place of the raw (scheme, value) pair.
     """
     target = get_entity(scheme, value)
     parent = super(NearbyEntityListView, self)
     return parent.get_metadata(request, target)
Exemple #39
0
    print('model_path:', model_path)
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(FLAGS=FLAGS,
                       embeddings=embeddings,
                       server=None,
                       num_workers=None,
                       word2id=word2id,
                       tag2label=tag2label,
                       paths=paths,
                       train_data_len=None)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while 1:
            print('Please input your sentence:')
            demo_sent = input('input:')
            if demo_sent == '' or demo_sent.isspace():
                print('see you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG, DUTY = get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}\nDUTY: {}'.format(
                    PER, LOC, ORG, DUTY))
def predictor(_):
    """Triggered by HTTP: process the frames waiting in the Datastore 'Queue'.

    Two strategies are used:

    * Batch: when at least ``TRESHOLD`` frames are queued, or a batch job is
      already being prepared (some queued entity has a ``'batch'`` property),
      the frames of one model version are written as JSON-lines input files
      under ``gs://BUCKET_NAME/<jobId>/batches/`` and a batch prediction job
      is launched.
    * Online: otherwise each queued frame is sent to the online prediction
      endpoint and its detections are stored as 'Prediction' / 'Object'
      entities referenced from the 'Frame' entity.

    When the elapsed time exceeds the branch's budget, the function
    re-triggers itself over HTTP (``get_no_response``) and returns, leaving
    the remaining queue entries for the new execution.

    Args:
        _: The HTTP request (unused).

    Returns:
        None.
    """
    start_time = time.time()
    client_datastore = datastore.Client()
    queue = list(client_datastore.query(kind='Queue').fetch())

    if not queue:
        print("No frames to process")
        return
    print("{} frames to process".format(len(queue)))

    # test
    print('Number of duplicates in the queue',
          count_duplicates([dict(x)['frame'] for x in queue]))

    # Above a certain amount of frames in the queue we pick batch instead of
    # online predictions, or if a batch is currently being prepared (some
    # queued frame has the key 'batch') and other input files still have to
    # be written.
    if len(queue) >= TRESHOLD or any('batch' in dict(x) for x in queue):
        _run_batch_prediction(queue, client_datastore, start_time)
    else:
        _run_online_prediction(queue, client_datastore, start_time)


def _run_batch_prediction(queue, client_datastore, start_time):
    """Write queued frames to GCS input files, then launch a batch job.

    Returns early, without launching the job, if a job is already running or
    if the time budget (40 s) is exhausted (after re-triggering the function).
    """
    # TODO: think about how to plan for later a job launch (after end of this one)
    if any_job_running(PROJECT_ID):
        print("A job is already running, exiting")
        return

    # Instantiates a GCS client
    storage_client = storage.Client()

    # Building the inputs may take several executions of the predictor. A
    # queue entity tagged with 'batch' carries the jobId chosen by the first
    # execution; reusing it keeps all input files in the same GCS folder.
    batch_entries = [x for x in queue if 'batch' in dict(x)]

    # Prepare inputs for the model of the job already in preparation, or
    # else for the model of the first queued frame.
    model = (batch_entries[0]['model'] if batch_entries
             else dict(queue[0])['model'])
    body = make_batch_job_body(project_name=PROJECT_ID,
                               bucket_name=BUCKET_NAME,
                               model_name=MODEL_NAME,
                               region=REGION,
                               version_name=model,
                               max_worker_count=72)
    # Only frames queued for exactly this model version belong to the job.
    # Equality, not a substring test: "v1" must not also select "v10".
    filtered_queue = [x for x in queue if dict(x)['model'] == model]

    if batch_entries:
        job_id = batch_entries[0]['batch']
        body['jobId'] = job_id
        # The input/output paths are derived from the jobId, so update them too.
        body['predictionInput']['inputPaths'] = 'gs://{}/{}/batches/*'.format(
            BUCKET_NAME, job_id)
        body['predictionInput']['outputPath'] = 'gs://{}/{}/batch_results'.format(
            BUCKET_NAME, job_id)
    else:
        # First execution for this job: tag the last frame of the batch right
        # away (this takes some execution time) so that follow-up executions
        # reuse this jobId.
        filtered_queue[-1]['batch'] = body['jobId']
        client_datastore.put(filtered_queue[-1])

    bucket = storage_client.get_bucket(BUCKET_NAME)
    # Creating multiple small input files (better scalability)
    for chunk_index, chunk in enumerate(chunks(filtered_queue, BATCH_CHUNK)):
        print('Chunk n°{}'.format(chunk_index + 1))
        elapsed_time = time.time() - start_time
        print('Elapsed time {0:.2f}'.format(elapsed_time))
        # Avoid timeout (40s)
        if elapsed_time > 40:
            # Re-trigger the function until everything is in input files
            get_no_response(
                'https://{}-{}.cloudfunctions.net/predictor'.format(
                    REGION, PROJECT_ID))
            # Chunks 0 .. chunk_index - 1 are written; this one is not
            # (assuming chunks() yields consecutive BATCH_CHUNK-sized slices).
            remaining = len(filtered_queue) - chunk_index * BATCH_CHUNK
            print(
                f'{remaining} frames left to write to input file for model {model}'
            )
            return
        _upload_chunk(chunk, client_datastore, bucket, body['jobId'])

    # Launch the batch prediction job
    batch_predict(PROJECT_ID, body)


def _upload_chunk(chunk, client_datastore, bucket, job_id):
    """Upload one chunk of queued frames as a JSON-lines input file.

    Each frame becomes one line (``frame_to_input`` output serialized with
    ``json.dumps``) of ``<job_id>/batches/inputs-<random id>.json`` in
    ``bucket``. The queue entities are deleted only after the upload has
    succeeded, so a failed upload leaves them queued for a later execution.
    """
    # Random name, must be different from other input files
    file_name = "inputs-{}.json".format(random_id())
    file_path = os.path.join("/tmp", file_name)
    try:
        # "a+" = append, as before; the random name makes the file new.
        with open(file_path, "a+") as json_file:
            for q in chunk:
                frame_entity = get_entity(client_datastore, 'Frame',
                                          dict(q)['frame'])
                json_file.write(json.dumps(frame_to_input(frame_entity)) + "\n")
        blob = bucket.blob(os.path.join(job_id, 'batches', file_name))
        blob.upload_from_filename(file_path)
    finally:
        # In Cloud Functions /tmp is an in-memory filesystem that can persist
        # across warm invocations, so always drop the local copy.
        if os.path.exists(file_path):
            os.remove(file_path)

    # Dismiss the frames from the queue now that their input is on GCS.
    for q in chunk:
        client_datastore.delete(q.key)


def _run_online_prediction(queue, client_datastore, start_time):
    """Send each queued frame to the online endpoint and store the results.

    For every frame, a 'Prediction' entity (model + list of 'Object' ids) is
    stored, its id is appended to the frame's ``predictions`` list, and the
    queue entity is deleted. Returns early after re-triggering the function
    once more than 30 s have elapsed.
    """
    import socket
    # https://stackoverflow.com/questions/48969145/how-to-set-the-request-timeout-in-google-ml-api-python-client
    socket.setdefaulttimeout(ONLINE_TIMEOUT)

    for processed, q in enumerate(queue):
        elapsed_time = time.time() - start_time
        print('Elapsed time {0:.2f}'.format(elapsed_time))

        # Avoid timeout
        # TODO: handle timeout treshold in relation to image size (model takes longer for bigger image)
        if elapsed_time > 30:
            # Re-trigger the function until everything is processed
            get_no_response(
                'https://{}-{}.cloudfunctions.net/predictor'.format(
                    REGION, PROJECT_ID))
            print('{} frames left to process'.format(len(queue) - processed))
            return

        frame_entity = get_entity(client_datastore, 'Frame', dict(q)['frame'])
        model_version = dict(q)['model']

        # Query AI Platform with the input
        result = online_predict(PROJECT_ID, MODEL_NAME,
                                [frame_to_input(frame_entity)], model_version)

        # Put the prediction in Datastore. Only the first prediction of the
        # response is used (a single instance is sent per request).
        entity_prediction = datastore.Entity(
            key=client_datastore.key('Prediction'))
        keys_object = _store_detected_objects(client_datastore,
                                              result['predictions'][0])
        entity_prediction.update({
            'model': model_version,
            'objects': keys_object
        })
        client_datastore.put(entity_prediction)

        # Update the predictions property of the Frame row. A 'processing'
        # marker is reset; otherwise the frame already holds predictions from
        # another model and the new one is appended.
        if 'processing' in frame_entity['predictions']:
            frame_entity['predictions'] = []
        frame_entity['predictions'].append(entity_prediction.id)
        client_datastore.put(frame_entity)

        # Dismiss the processed message from the queue
        client_datastore.delete(q.key)


def _store_detected_objects(client_datastore, prediction, min_score=0.1):
    """Store each detection scoring above ``min_score`` as an 'Object' entity.

    Returns the list of generated 'Object' ids, in detection order.
    """
    keys_object = []
    for idx in range(int(prediction['num_detections'])):
        score = prediction['detection_scores'][idx]
        if score > min_score:
            entity_object = datastore.Entity(key=client_datastore.key('Object'))
            entity_object.update({
                'detection_classes': int(prediction['detection_classes'][idx]),
                'detection_boxes': prediction['detection_boxes'][idx],
                'detection_scores': score,
            })
            client_datastore.put(entity_object)
            # Store the generated id for reference in the Prediction table
            keys_object.append(entity_object.id)
    return keys_object
Exemple #41
0
 def initial_context(self, request, scheme, value):
     """Build the initial context holding the entity for ``scheme``/``value``."""
     context = dict(entity=get_entity(scheme, value))
     return context
Exemple #42
0
    with tf.Session(config=model.config) as sess:
        sess.run(tf.global_variables_initializer())
        model.train(sess=sess, train=train_data, dev=train_data, saver=saver)

elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    log_path['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, tag2label, vocab, log_path, logger, config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('Start demo ...')
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input sentence(pause enter or space to exit):')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('Error for input format, see you next time!')
                break
            else:
                try:
                    demo_sent = list(demo_sent.strip())
                    demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                    tag = model.demo_one(sess, demo_data)
                    NOR, VER, ENG, OTH = get_entity(tag, demo_sent)
                    print('NOR: {}, VER: {}, ENG: {}, OTH: {}'.format(
                        NOR, VER, ENG, OTH))
                except:
                    print('Please switch to manual service ...')
Exemple #43
0
def trainAll(args):
    """Dispatch on ``args.mode`` to train, test, demo or export the BiLSTM-CRF.

    Modes:
        'train'     -- build a new model and train it on ``train_data``;
                       ``test_data`` serves as the dev set so overfitting is
                       visible during training.
        'test'      -- restore the newest checkpoint under ``model_path`` and
                       evaluate it on ``test_data``.
        'demo'      -- restore the newest checkpoint, then repeatedly read a
                       sentence from stdin and print its PER/LOC/ORG entities;
                       an empty or whitespace-only line ends the loop.
        'savemodel' -- restore the newest checkpoint and write a SavedModel
                       to ``./savemodels``.
    Any other mode does nothing.

    Reads module-level globals defined elsewhere in the file (embeddings,
    tag2label, word2id, paths, config, model_path, train_data, test_data,
    test_size) and sets ``paths['model_path']`` when restoring a checkpoint.

    NOTE(review): the 'savemodel' branch looks broken. It tags the string repr
    of a placeholder tensor instead of real input, and it passes a Python list
    and dict (not tensors) to ``predict_signature_def``. It needs a rewrite
    against the model's actual input/output tensors.
    """

    def _make_model():
        # Every mode builds the model with the same constructor arguments.
        bilstm = BiLSTM_CRF(args,
                            embeddings,
                            tag2label,
                            word2id,
                            paths,
                            config=config)
        bilstm.build_graph()
        return bilstm

    def _make_restored_model():
        # Point the model at the newest checkpoint before building its graph.
        latest = tf.train.latest_checkpoint(model_path)
        print(latest)
        paths['model_path'] = latest
        return latest, _make_model()

    mode = args.mode

    if mode == 'train':
        model = _make_model()
        # For hyper-parameter tuning, split a dev set off train_data instead.
        print("train data: {}".format(len(train_data)))
        # use test_data as the dev_data to see overfitting phenomena
        model.train(train=train_data, dev=test_data)

    elif mode == 'test':
        _, model = _make_restored_model()
        print("test data: {}".format(test_size))
        model.test(test_data)

    elif mode == 'demo':
        ckpt_file, model = _make_restored_model()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            print('============= demo =============')
            saver.restore(sess, ckpt_file)
            while True:
                print('Please input your sentence:')
                line = input()
                if line == '' or line.isspace():
                    print('See you next time!')
                    break
                chars = list(line.strip())
                tag = model.demo_one(sess, [(chars, ['O'] * len(chars))])
                PER, LOC, ORG = get_entity(tag, chars)
                print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))

    elif mode == 'savemodel':
        ckpt_file, model = _make_restored_model()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            saver.restore(sess, ckpt_file)
            input_placeholder = tf.placeholder(tf.string, name='input')
            chars = list(str(input_placeholder).strip())
            tag = model.demo_one(sess, [(chars, ['O'] * len(chars))])
            PER, LOC, ORG = get_entity(tag, chars)
            result = {'PER': PER, 'LOC': LOC, 'ORG': ORG}
            print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
            # Save as a SavedModel
            builder = tf.saved_model.builder.SavedModelBuilder('./savemodels')
            signature = predict_signature_def(inputs={'input': chars},
                                              outputs={'output': result})
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={'predict': signature})
            builder.save()
            print('savemodel saves')
Exemple #44
0
 def handle_GET(self, request, context, scheme, value):
     """Handle GET by resolving ``scheme``/``value`` to an entity and delegating.

     The parent class's ``handle_GET`` receives the resolved entity in place
     of the raw (scheme, value) pair.
     """
     target = get_entity(scheme, value)
     parent = super(NearbyEntityListView, self)
     return parent.handle_GET(request, context, target)
Exemple #45
0
 file_name_list = file_name('../data/test2')#读取测试集
 output_path = '../data/test_ori'
 if not os.path.exists(output_path): os.makedirs(output_path)
 for name in file_name_list:
     name_dir =  '../data/test2/' + name
     write_dir = '../data/test_ori/rec_' + name#预测完后写入的文件
     f = open(name_dir, 'r', encoding = 'utf-8')
     fw = open(write_dir, 'w', encoding = 'utf-8')
     while True:
         line = f.readline()
         if line == '':
             break
         line = list(line.strip())
         demo_data = [(line, ['O'] * len(line))]
         tag = model.demo_one(sess, demo_data)
         result = get_entity(tag, line)
         #result = sorted(result1[0].items(), key=lambda x:x[1], reverse = False)				
     entity_list = result
     #print(entity_list)
     #j_sort = []
     
     i = 0
     for en in entity_list:
         en_write1 = en['start']
         en_write2 = en['end']
         en_write3 = en['type']
         en_write4 = en['word']
         i += 1
         fw.write(str(en_write1)+'\t'+str(en_write2)+'\t'+en_write3+'\t'+en_write4+'\n')
     print(entity_list)
     fw.close()
Exemple #46
0
## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        #使用 saver.restore() 方法,重载模型的参数,继续训练或用于测试数据。
        saver.restore(sess, ckpt_file)
        while (1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                entities = get_entity(tag, demo_sent)
                print({i: entities[i] for i in entities.keys()})
                #print('PER: {}\nLOC: {}\nORG: {}\nTIME: {}\nROLE: {}'.format(PER, LOC, ORG, TIME, ROLE))
Exemple #47
0
    model.test(test_data)

## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while (1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG = get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Exemple #48
0
## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       dictname2id,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while (1):
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(sentence2id(demo_sent,
                                          word2id), [0] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                res = get_entity(tag[0], demo_sent, dictname2id)
                print(res)