Example #1
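# Note: this snippet assumes it lives inside (or imports from) the Mozilla
# DeepSpeech codebase, so that tf, np, FLAGS, initialize_globals, stopwatch,
# format_duration, audiofile_to_input_vector, ndarray_to_text, alphabet,
# n_input and n_context are already in scope; the exact import paths are not
# shown here and depend on the DeepSpeech version used.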
def _callback(result_future):
    exception = result_future.exception()
    if exception:
        print(exception)
    else:
        results = tf.contrib.util.make_ndarray(result_future.result().outputs['outputs'])
        for result in results[0]:
            print(ndarray_to_text(result))
    event.set()
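
# Hedged sketch of how _callback above would typically be wired up: it matches
# the TensorFlow Serving gRPC client pattern, where the callback is registered
# on the future returned by Predict.future() and signals a threading.Event
# that the caller waits on. The host/port, model name and 'input' tensor key
# below are assumptions, not taken from this script; the function is never
# called here.
def _serving_predict_example(features):
    from grpc.beta import implementations
    from tensorflow_serving.apis import predict_pb2, prediction_service_pb2

    channel = implementations.insecure_channel('localhost', 9000)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'deepspeech'  # assumed model name
    request.inputs['input'].CopyFrom(
        tf.contrib.util.make_tensor_proto(features, dtype=tf.float32))
    result_future = stub.Predict.future(request, 10.0)  # 10 second deadline
    result_future.add_done_callback(_callback)
    return result_future
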
def main(_):

    start = stopwatch()
    initialize_globals()

    if FLAGS.one_shot_infer:
        # load the frozen graph as in train(...) or as in https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
        with tf.gfile.FastGFile("../../models/output_graph.pb", 'rb') as fin:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(fin.read())
        with tf.Graph().as_default() as pretrained_model:
            tf.import_graph_def(graph_def, name="pretrained_")
        """
        for op in pretrained_model.get_operations():
            print(op.name)
        """

        #        print("------------***-------------")

        # https://stackoverflow.com/questions/36883949/in-tensorflow-get-the-names-of-all-the-tensors-in-a-graph?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
        lstTensors = [op.values() for op in pretrained_model.get_operations()]
        input_node = lstTensors[0]
        input_lengths = lstTensors[1]
        output_node = lstTensors[-1]
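        # Alternative sketch (tensor names are assumptions; verify against the
        # commented-out op printout above): fetch the imported tensors by explicit
        # name rather than by position. Note the "pretrained_/" prefix added by
        # import_graph_def and the ":0" output index. These calls return single
        # Tensors rather than the 1-tuples produced by op.values() above, so the
        # feed/fetch wrapping further down would change accordingly.
        # input_node = pretrained_model.get_tensor_by_name("pretrained_/input_node:0")
        # input_lengths = pretrained_model.get_tensor_by_name("pretrained_/input_lengths:0")
        # output_node = pretrained_model.get_tensor_by_name("pretrained_/output_node:0")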
        """
        print("input node name: ")
        print(input_node[0].name)
        print("input node shape: ")
        print(input_node[0].shape)#V IMP: shape of input node is [x,y,z] where x = batch_size. For one shot infer, batch_size = 1
        print("input lengths name: ")
        print(input_lengths[0].name)
        print("input lengths shape: ")
        print(input_lengths[0].shape) #V IMP: shape of input_lengths node is [x,y] where x = batch_size. For one shot infer, batch_size = 1
        print("output node name: ")
        print(output_node[0].name)
        print("output node shape: ")
        print(output_node[0].shape)
        """

        #        do_single_file_inference(FLAGS.one_shot_infer)
        #        print("n_input = "+repr(n_input))
        #        print("n_context = "+repr(n_context))
        mfcc = audiofile_to_input_vector(FLAGS.one_shot_infer, n_input,
                                         n_context)
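        # audiofile_to_input_vector returns a 2-D float array of MFCC features with
        # context windows, roughly [time_steps, n_input + 2*n_input*n_context]
        # (494 columns with DeepSpeech's defaults n_input=26, n_context=9; this is
        # an assumption that the commented-out print below can confirm).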
        #        print(mfcc.shape)

        #        output_node = pretrained_model.get_tensor_by_name(pretrained_model.get_operations()[-1].name)

        batch_size = 1
        with tf.Session(graph=pretrained_model) as sess:
            output = sess.run(
                output_node,
                feed_dict={
                    input_node:
                    [mfcc.reshape((batch_size, mfcc.shape[0], mfcc.shape[1]))],
                    input_lengths:
                    [np.array(len(mfcc)).reshape((batch_size, ))]
                })
            #            print(output)
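            # output is a 1-tuple because output_node above came from op.values();
            # output[0] is the fetched array, and the extra [0][0] strips the leading
            # top-path and batch dimensions (an assumption about the exported graph's
            # output layout; batch_size is 1 here).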
            text = ndarray_to_text(output[0][0][0], alphabet)
            print("\n\nResult:")
            print(text)
    else:
        print(
            "Correct usage: python3 _this.py --one_shot_infer <<path-of-input-wav-file>>"
        )

    delta = stopwatch(start)
    print("Net execution time including loading of the graph = " +
          format_duration(delta))
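
# Entry point, assuming the TF 1.x flags/app pattern used elsewhere in DeepSpeech
# (an assumption; the original snippet does not show how main() is invoked):
if __name__ == '__main__':
    tf.app.run()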