from model_meta import NETS
import os
import subprocess
import pdb

# Image list file; passed as the first positional argument of each test run.
TEST_IMAGE_PATH = "data/images/list.txt"
# Results file; deleted at startup if it already exists.
TEST_OUTPUT_PATH = "data/test_output_trt.txt"
# NOTE(review): not referenced in the visible code; presumably the TensorRT
# test binary that the assembled args are handed to — confirm.
TEST_EXE_PATH = "./build/src/test/test_trt"

if __name__ == '__main__':

    # delete output file
    if os.path.isfile(TEST_OUTPUT_PATH):
        os.remove(TEST_OUTPUT_PATH)

    for net_name, net_meta in NETS.items():
        if 'exclude' in net_meta.keys() and net_meta['exclude'] is True:
            continue

        args = [
            TEST_IMAGE_PATH,
            net_meta['plan_filename'],
            net_meta['input_name'],
            str(net_meta['input_height']),
            str(net_meta['input_width']),
            net_meta['output_names'][0],
            str(net_meta['num_classes']),
            net_meta['preprocess_fn'].__name__,
            str(50),  # numRuns
            "float",  # dataType 
            str(32),  # maxBatchSize 
Example #2
0
                                                                                       'high_performance',
                                                                                       'power_saver',
                                                                                       'system_settings',
                                                                                       'sustained_high_performance',
                                                                                       'burst',
                                                                                       'default'])

args = parser.parse_args()

    # Only emit INFO-level progress/result logging when --verbose was given.
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # Write the CSV header. Opened with mode 'w+' (presumably truncating any
    # previous results).
    # NOTE(review): output_manager is defined elsewhere; here it is used as a
    # context manager taking (path, mode) and yielding an object with .write().
    with output_manager(args.output_file, 'w+') as output:
        output.write("net_name,throughput\n")

    # Benchmark every network not flagged with exclude=True, in order of
    # network name (sorted() on the (name, meta) item pairs).
    for net_name, net_meta in sorted(NETS.items()):
        if net_meta.get('exclude') is True:
            logging.info("Skipping %s", net_name)
            continue

        logging.info("Testing %s", net_name)
        # Dispatch on the requested backend. Any net_type other than 'tf' or
        # 'trt' is measured with SNPE, matching the original fall-through.
        if args.net_type == 'tf':
            throughput = test_tf_average_throughput(net_meta, args.num_runs)
        elif args.net_type == 'trt':
            throughput = test_trt_average_throughput(
                net_meta, args.data_type, args.num_runs)
        else:
            throughput = test_snpe_average_throughput(
                net_meta, args.runtime, args.num_runs,
                args.performance_profile)

        csv_result = '{},{}'.format(net_name, throughput)
        logging.info("%s", csv_result)
        # Reopen in append mode for every row; presumably so results already
        # measured stay on disk even if a later network fails.
        with output_manager(args.output_file, 'a') as output:
            output.write(csv_result + '\n')