Example #1
def main(_):
    # Build the bridge with intel model zoo configuration for inference run
    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "../config/models.json")
    model = ModelZooBridge(args.model, args.in_graph, args.data_location,
                           args.models_zoo_location, args.models_source_dir,
                           config_file_path)

    # pick up the quantization calibration parameters from models.json
    inference_cmd_gen_minmax_log = model.inference_calib_cmd
    graph_converter_params = model.quantize_params_dict
    inputs = []
    outputs = []
    excluded_ops = []
    excluded_nodes = []
    per_channel = False
    if ModelZooBridge.INPUT_NODE_LIST in graph_converter_params:
        inputs = graph_converter_params[ModelZooBridge.INPUT_NODE_LIST]
    if ModelZooBridge.OUTPUT_NODE_LIST in graph_converter_params:
        outputs = graph_converter_params[ModelZooBridge.OUTPUT_NODE_LIST]
    if ModelZooBridge.EXCLUDED_OPS_LIST in graph_converter_params:
        excluded_ops = graph_converter_params[ModelZooBridge.EXCLUDED_OPS_LIST]
    if ModelZooBridge.EXCLUDED_NODE_LIST in graph_converter_params:
        excluded_nodes = graph_converter_params[
            ModelZooBridge.EXCLUDED_NODE_LIST]
    if ModelZooBridge.PER_CHANNEL_FLAG in graph_converter_params:
        per_channel = graph_converter_params[ModelZooBridge.PER_CHANNEL_FLAG]

    # Call the GraphConverter to do the FP32 calibration, INT8 quantization, and INT8 calibration
    qt = converter.GraphConverter(args.in_graph, args.out_graph, inputs,
                                  outputs, excluded_ops, excluded_nodes,
                                  per_channel)
    qt.debug = args.debug
    qt.gen_calib_data_cmds = inference_cmd_gen_minmax_log
    qt.convert()
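
This example leans on a module-level args object that is parsed elsewhere in the script. A minimal sketch of argument parsing that would supply the fields used above; the flag spellings and help text are assumptions, not taken from the original script:

import argparse

# Hypothetical parser: the flag names simply mirror the attributes the example
# reads from `args`; the real script's spellings and defaults may differ.
parser = argparse.ArgumentParser(description="Quantize an Intel Model Zoo graph")
parser.add_argument("--model", help="model name as listed in models.json")
parser.add_argument("--in_graph", help="path to the FP32 frozen graph")
parser.add_argument("--out_graph", help="path for the INT8 output graph")
parser.add_argument("--data_location", help="calibration dataset location")
parser.add_argument("--models_zoo_location", help="Model Zoo checkout path")
parser.add_argument("--models_source_dir", help="TensorFlow models repo path")
parser.add_argument("--debug", action="store_true", help="enable converter debug output")
args = parser.parse_args()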
Example #2
def main(_):
    c = None

    # Defaults for the output node name and input image size;
    # overridden for specific models below.
    per_channel_value = False
    output_shape = args.model + "/predictions/Reshape_1"
    image_size = 224

    if args.model == 'inception_v1':
        output_shape = 'InceptionV1/Logits/Predictions/Reshape_1'

    elif args.model == 'inception_v2':
        output_shape = 'InceptionV2/Predictions/Reshape_1'

    elif args.model == 'inception_v4':
        output_shape = 'InceptionV4/Logits/Predictions'
        image_size = 299

    elif args.model == 'mobilenet_v1':
        per_channel_value = True
        output_shape = 'MobilenetV1/Predictions/Reshape_1'

    elif args.model == 'mobilenet_v2':
        per_channel_value = True
        output_shape = 'MobilenetV2/Predictions/Reshape_1'

    elif args.model == 'vgg_16':
        output_shape = 'vgg_16/fc8/squeezed'

    elif args.model == 'vgg_19':
        output_shape = 'vgg_19/fc8/squeezed'

    elif args.model in ('nasnet_large', 'pnasnet_large'):
        output_shape = 'final_layer/predictions'
        image_size = 331

    if per_channel_value:
        c = converter.GraphConverter(args.model_location, args.out_graph, ['input'], [output_shape],
                                     per_channel=True)
    else:
        c = converter.GraphConverter(args.model_location, args.out_graph, ['input'], [output_shape])
    
    c.debug = True
    c.gen_calib_data_cmds = model_callback_cmds(args.data_location, output_shape, image_size)
    c.convert()
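
The two constructor calls in this example differ only in the per_channel keyword. Assuming GraphConverter's per_channel parameter defaults to False, the branch can be collapsed into a single call:

# Equivalent construction, assuming per_channel defaults to False
# in converter.GraphConverter.
c = converter.GraphConverter(args.model_location, args.out_graph,
                             ['input'], [output_shape],
                             per_channel=per_channel_value)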
Example #3
def main(_):

    print(args.inputs.split(','), args.outputs.split(','), args.output_graph)
    if not os.path.exists(args.input_graph):
        print("{} doesn't exist!".format(args.input_graph))
        sys.exit(-1)

    if args.inputs:
        inputs = args.inputs.split(',')
    else:
        inputs = []

    if args.outputs:
        outputs = args.outputs.split(',')
    else:
        outputs = []

    if args.excluded_ops:
        excluded_ops = args.excluded_ops.split(',')
    else:
        excluded_ops = []

    if args.excluded_nodes:
        excluded_nodes = args.excluded_nodes.split(',')
    else:
        excluded_nodes = []

    qt = converter.GraphConverter(args.input_graph, args.output_graph, inputs,
                                  outputs, excluded_ops, excluded_nodes,
                                  args.per_channel)
    qt.debug = args.debug
    # If the callback command pins input_graph to a fixed file, swap that value
    # for a '{}' placeholder so the converter can substitute whichever graph it
    # needs when running the calibration command.
    if 'input_graph=' in args.callback:
        prefix = args.callback.split('input_graph=')[0]
        postfix = ' '.join(
            args.callback.split('input_graph=')[-1].split(' ')[1:])
        callback_cmd = prefix + 'input_graph={} ' + postfix
    else:
        callback_cmd = args.callback
    qt.gen_calib_data_cmds = callback_cmd
    qt.convert()
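
The string handling around args.callback rewrites a concrete input_graph=... argument into a '{}' placeholder. A small, self-contained illustration with a made-up callback command:

# Illustrative only; the callback command and paths are made up.
callback = "python eval.py input_graph=/tmp/fp32_graph.pb --batch_size 64"

prefix = callback.split('input_graph=')[0]
postfix = ' '.join(callback.split('input_graph=')[-1].split(' ')[1:])
template = prefix + 'input_graph={} ' + postfix
# template == "python eval.py input_graph={} --batch_size 64"

# The converter can later point the command at whichever graph it needs:
print(template.format('/tmp/int8_logged_graph.pb'))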
Example #4
def quantization(input_graph, model_path, tf_records):
    # Convert the FP32 MiniGo graph to INT8, writing the quantized graph to
    # model_path + '.pb' and using tf_records to drive calibration.
    minigo_converter = converter.GraphConverter(
        input_graph, model_path + '.pb', ["pos_tensor"],
        ["policy_output", "value_output"])
    minigo_converter.gen_calib_data_cmds = minigo_callback_cmds(tf_records)
    minigo_converter.quantize()
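
A usage sketch with hypothetical paths; per the constructor calls in the other examples, the quantized graph is written to model_path + '.pb':

# Hypothetical paths for illustration.
quantization(
    input_graph="/tmp/minigo_fp32_graph.pb",
    model_path="/tmp/minigo_int8",
    tf_records="/tmp/minigo_calibration.tfrecord")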