def _trim_and_decode(ids):
    """Decode ``ids`` to a string, cutting the sequence at the first EOS id.

    Everything from the first ``tokenizer.EOS_ID`` onward is dropped before
    decoding. If a ``ValueError`` is raised (typically because no EOS id is
    present in ``ids``), the whole sequence is decoded instead.
    """
    decoder = Subtokenizer(FLAGS.vocab_file)
    try:
        eos_position = list(ids).index(tokenizer.EOS_ID)
        trimmed_ids = ids[:eos_position]
        text = decoder.decode(trimmed_ids)
    except ValueError:
        # No EOS marker in the sequence: decode everything.
        text = decoder.decode(ids)
    return text
Beispiel #2
0
def eval_func(infer_graph, iteration=-1):
    """Run inference on ``infer_graph`` and report latency, throughput and BLEU.

    Batches are read from a dataloader built over ``FLAGS.inputs_file`` /
    ``FLAGS.reference_file`` and fed to the ``input_tensor:0`` tensor; the
    ``model/Transformer/strided_slice_19:0`` tensor is fetched as the output.
    Per-batch wall-clock time is recorded and the first ``warmup`` (10)
    batches are excluded from the reported latency whenever later batches
    exist.

    Args:
        infer_graph: A ``tf.Graph``, or a ``tf.compat.v1.GraphDef`` which is
            first imported into a fresh graph.
        iteration: Number of batches to run. ``-1`` (default) runs the whole
            dataloader and additionally computes the BLEU score. Any other
            value must be at least ``warmup``.

    Returns:
        The BLEU result when ``iteration == -1``; otherwise ``None``.
    """
    if isinstance(infer_graph, tf.compat.v1.GraphDef):
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(infer_graph, name='')
        infer_graph = graph

    subtokenizer = Subtokenizer(FLAGS.vocab_file)
    input_tensor = infer_graph.get_tensor_by_name('input_tensor:0')
    output_tensor = infer_graph.get_tensor_by_name(
        'model/Transformer/strided_slice_19:0')
    ds = Dataset(FLAGS.inputs_file, FLAGS.reference_file, FLAGS.vocab_file)
    from lpot.data import DATALOADERS
    dataloader = DATALOADERS['tensorflow'](ds,
                                           batch_size=FLAGS.batch_size,
                                           collate_fn=collate_fn)
    config = tf.compat.v1.ConfigProto()
    config.use_per_session_threads = 1
    config.inter_op_parallelism_threads = 1

    run_all = iteration == -1
    warmup = 10
    if not run_all:
        # Checked before the session is created so a failure leaks nothing.
        assert iteration >= warmup, 'iteration must be larger than warmup'

    bleu_eval = bleu()
    time_list = []
    predictions = []
    labels = []

    sess = tf.compat.v1.Session(graph=infer_graph, config=config)
    try:
        for idx, (input_data, label) in enumerate(dataloader):
            if not run_all and idx >= iteration:
                break
            time_start = time.time()
            out = sess.run([output_tensor], {input_tensor: input_data})
            duration = time.time() - time_start
            time_list.append(duration)
            predictions.append(out)
            labels.extend(label)
    finally:
        # Always release the session, even if a batch fails.
        sess.close()

    # Exclude warm-up batches from the average. When no batch ran beyond the
    # warm-up window (e.g. iteration == warmup or a small dataset), average
    # all timings instead of taking the mean of an empty array (NaN).
    measured = time_list[warmup:] or time_list
    latency = np.array(measured).mean() / FLAGS.batch_size
    print('Batch size = {}'.format(FLAGS.batch_size))
    print('Latency: {:.3f} ms'.format(latency * 1000))
    print('Throughput: {:.3f} items/sec'.format(1. / latency))

    # only calculate accuracy when running out all predictions
    if not run_all:
        return None

    decode = []
    # Each entry of `predictions` is the list returned by sess.run for one
    # batch; iterate fetched outputs, then the individual sequences in each.
    for batch_outputs in predictions:
        for output in batch_outputs:
            for token_ids in output:
                try:
                    eos_index = list(token_ids).index(tokenizer.EOS_ID)
                except ValueError:  # No EOS found in sequence
                    decode.append(subtokenizer.decode(token_ids))
                else:
                    decode.append(subtokenizer.decode(token_ids[:eos_index]))

    bleu_eval.update(decode, labels)
    accuracy = bleu_eval.result()
    print('Accuracy is {:.3f}'.format(accuracy))
    return accuracy