예제 #1
0
def RunModel(saved_model_dir, signature_def_key, tag, text, ngrams_list=None):
    """Run one piece of text through an exported SavedModel.

    Args:
      saved_model_dir: directory handed to `reader.read_saved_model` and
        `loader.load`.
      signature_def_key: signature to run. "proba" reads the "scores"
        output, "embedding" reads the "outputs" output.
      tag: tag selecting which meta graph of the SavedModel to use.
      text: raw text; tokenized with `text_utils.TokenizeText`.
      ngrams_list: optional ngram options parsed with
        `text_utils.ParseNgramsOpts`; when None no ngrams are generated.

    Returns:
      The result of `sess.run` on the selected output tensor.

    Raises:
      ValueError: if no meta graph carries `tag`, or if
        `signature_def_key` is neither "proba" nor "embedding".
    """
    saved_model = reader.read_saved_model(saved_model_dir)
    meta_graph = None
    for candidate in saved_model.meta_graphs:
        if tag in candidate.meta_info_def.tags:
            meta_graph = candidate
            break
    # Check the search result, not the loop variable: the loop variable is
    # never None and is undefined when the SavedModel has no meta graphs.
    if meta_graph is None:
        raise ValueError("Cannot find saved_model with tag %s" % (tag,))
    signature_def = signature_def_utils.get_signature_def_by_key(
        meta_graph, signature_def_key)

    tokens = text_utils.TokenizeText(text)
    ngrams = None
    if ngrams_list is not None:
        ngrams_list = text_utils.ParseNgramsOpts(ngrams_list)
        ngrams = text_utils.GenerateNgrams(tokens, ngrams_list)
    example = inputs.BuildTextExample(tokens, ngrams=ngrams)
    serialized_example = example.SerializeToString()
    inputs_feed_dict = {
        signature_def.inputs["inputs"].name: [serialized_example],
    }

    if signature_def_key == "proba":
        output_key = "scores"
    elif signature_def_key == "embedding":
        output_key = "outputs"
    else:
        raise ValueError("Unrecognised signature_def %s" % (signature_def_key))
    output_tensor = signature_def.outputs[output_key].name

    with tf.Session() as sess:
        loader.load(sess, [tag], saved_model_dir)
        return sess.run(output_tensor, feed_dict=inputs_feed_dict)
예제 #2
0
def Request(text):
    """Build a ClassificationRequest holding one example made from `text`.

    The text is tokenized with `text_utils.TokenizeText` and wrapped by
    `inputs.BuildTextExample`. The request targets the model named
    'default' with the signature taken from `FLAGS.signature_def`.

    Args:
      text: raw text to classify.

    Returns:
      The populated `classification_pb2.ClassificationRequest`.
    """
    tokens = text_utils.TokenizeText(text)
    text_example = inputs.BuildTextExample(tokens)

    request = classification_pb2.ClassificationRequest()
    model_spec = request.model_spec
    model_spec.name = 'default'
    model_spec.signature_name = FLAGS.signature_def

    example_list = request.input.example_list
    example_list.examples.extend([text_example])
    return request
예제 #3
0
def Request(text, ngrams):
    """Build a ClassificationRequest for `text` against the 'proba' signature.

    Args:
      text: raw text; tokenized with `text_utils.TokenizeText`.
      ngrams: ngram options handed to `text_utils.ParseNgramsOpts`, or
        None to build the example without ngrams.

    Returns:
      The populated `classification_pb2.ClassificationRequest`, addressed
      to the model named 'default'.
    """
    tokens = text_utils.TokenizeText(text)
    # Keep the generated ngrams in their own local so the `ngrams` argument
    # is not overwritten before it is inspected.
    generated_ngrams = None
    if ngrams is not None:
        ngrams_list = text_utils.ParseNgramsOpts(ngrams)
        generated_ngrams = text_utils.GenerateNgrams(tokens, ngrams_list)
    example = inputs.BuildTextExample(tokens, ngrams=generated_ngrams)
    request = classification_pb2.ClassificationRequest()
    request.model_spec.name = 'default'
    request.model_spec.signature_name = 'proba'
    request.input.example_list.examples.extend([example])
    return request
예제 #4
0
def WriteExamples(examples, outputfile, num_shards):
    """Write examples in TFRecord format, split across shard files.

    Args:
      examples: list of feature dicts.
                {'text': [words], 'label': [labels]}
      outputfile: full pathname of output file; each shard is written to
                  outputfile + '-<shard>-of-<num_shards>', with shards
                  numbered from 1.
      num_shards: upper bound on the number of shard files, also used in
                  the file-name suffix. Each shard receives
                  len(examples) // num_shards + 1 examples (the last one
                  possibly fewer), so fewer files may be produced.
    """
    shard = 0
    # Floor division keeps the shard size an integer on Python 3; with true
    # division the modulo test below would usually only match at n == 0.
    num_per_shard = len(examples) // num_shards + 1
    writer = None
    try:
        for n, example in enumerate(examples):
            if n % num_per_shard == 0:
                # Finish the previous shard before starting the next one.
                if writer is not None:
                    writer.close()
                shard += 1
                writer = tf.python_io.TFRecordWriter(
                    outputfile + '-%d-of-%d' % (shard, num_shards))
            record = inputs.BuildTextExample(example["text"], example["label"])
            writer.write(record.SerializeToString())
    finally:
        # Close the last open shard even if building or writing a record
        # fails.
        if writer is not None:
            writer.close()