Example #1
0
def gen_random_input(model_output_dir,
                     input_nodes,
                     input_shapes,
                     input_files,
                     input_ranges,
                     input_data_types,
                     input_file_name="model_input"):
    """Generate input data for every node, then overlay user-supplied files.

    Any existing input file for a node in ``input_nodes`` is removed from
    ``model_output_dir`` first. ``generate_input_data`` is then called once
    for all nodes. Finally, for each node whose entry in ``input_files`` is
    not None, that file is downloaded (``http://`` / ``https://`` URLs) or
    copied over the generated one.

    Args:
        model_output_dir: Directory the per-node input files are written to.
        input_nodes: Input node names (list/tuple, or a single name).
        input_shapes: Shape strings, joined with ``:`` for the generator.
        input_files: Per-node source paths/URLs (list/tuple, or a single
            value). ``None`` means no files are supplied. A ``None`` entry
            keeps the generated data for that node.
        input_ranges: Value-range strings, joined with ``:``.
        input_data_types: Data-type strings, joined with ``,``.
        input_file_name: Base name passed to ``common.formatted_file_name``.

    Raises:
        Exception: If files are supplied but their count differs from the
            number of input names.
    """
    for input_name in input_nodes:
        formatted_name = common.formatted_file_name(
            input_file_name, input_name)
        if os.path.exists("%s/%s" % (model_output_dir, formatted_name)):
            sh.rm("%s/%s" % (model_output_dir, formatted_name))
    input_nodes_str = ",".join(input_nodes)
    input_shapes_str = ":".join(input_shapes)
    input_ranges_str = ":".join(input_ranges)
    input_data_types_str = ",".join(input_data_types)
    generate_input_data("%s/%s" % (model_output_dir, input_file_name),
                        input_nodes_str,
                        input_shapes_str,
                        input_ranges_str,
                        input_data_types_str)

    # Treat ``None`` as "no input files" rather than as a single None
    # entry, which would otherwise fail the length check below for
    # multi-input models.
    if input_files is None:
        input_file_list = []
    elif isinstance(input_files, (list, tuple)):
        input_file_list = list(input_files)
    else:
        input_file_list = [input_files]
    if not input_file_list:
        return

    if isinstance(input_nodes, (list, tuple)):
        input_name_list = list(input_nodes)
    else:
        input_name_list = [input_nodes]
    if len(input_file_list) != len(input_name_list):
        raise Exception('If input_files set, the input files should '
                        'match the input names.')
    for input_name, src_file in zip(input_name_list, input_file_list):
        if src_file is None:
            # No file for this node: keep the data generated above.
            continue
        dst_input_file = "%s/%s" % (
            model_output_dir,
            common.formatted_file_name(input_file_name, input_name))
        if src_file.startswith(("http://", "https://")):
            six.moves.urllib.request.urlretrieve(src_file, dst_input_file)
        else:
            sh.cp("-f", src_file, dst_input_file)
Example #2
0
def gen_random_input(model_output_dir,
                     input_nodes,
                     input_shapes,
                     input_files,
                     input_ranges,
                     input_data_types,
                     input_file_name="model_input"):
    """Generate input data for every node, then overlay user-supplied files.

    Any existing input file for a node in ``input_nodes`` is removed from
    ``model_output_dir`` first. ``generate_input_data`` is then called once
    for all nodes. Finally, for each node whose entry in ``input_files`` is
    not None, that file is downloaded (``http://`` / ``https://`` URLs) or
    copied over the generated one.

    Args:
        model_output_dir: Directory the per-node input files are written to.
        input_nodes: Input node names (list/tuple, or a single name).
        input_shapes: Shape strings, joined with ``:`` for the generator.
        input_files: Per-node source paths/URLs (list/tuple, or a single
            value). ``None`` means no files are supplied. A ``None`` entry
            keeps the generated data for that node.
        input_ranges: Value-range strings, joined with ``:``.
        input_data_types: Data-type strings, joined with ``,``.
        input_file_name: Base name passed to ``common.formatted_file_name``.

    Raises:
        Exception: If files are supplied but their count differs from the
            number of input names.
    """
    # ``urllib.urlretrieve`` exists only on Python 2; on Python 3 the
    # function lives in ``urllib.request``.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    for input_name in input_nodes:
        formatted_name = common.formatted_file_name(
            input_file_name, input_name)
        if os.path.exists("%s/%s" % (model_output_dir, formatted_name)):
            sh.rm("%s/%s" % (model_output_dir, formatted_name))
    input_nodes_str = ",".join(input_nodes)
    input_shapes_str = ":".join(input_shapes)
    input_ranges_str = ":".join(input_ranges)
    input_data_types_str = ",".join(input_data_types)
    generate_input_data("%s/%s" % (model_output_dir, input_file_name),
                        input_nodes_str,
                        input_shapes_str,
                        input_ranges_str,
                        input_data_types_str)

    # Treat ``None`` as "no input files" rather than as a single None
    # entry, which would otherwise fail the length check below for
    # multi-input models.
    if input_files is None:
        input_file_list = []
    elif isinstance(input_files, (list, tuple)):
        input_file_list = list(input_files)
    else:
        input_file_list = [input_files]
    if not input_file_list:
        return

    if isinstance(input_nodes, (list, tuple)):
        input_name_list = list(input_nodes)
    else:
        input_name_list = [input_nodes]
    if len(input_file_list) != len(input_name_list):
        raise Exception('If input_files set, the input files should '
                        'match the input names.')
    for input_name, src_file in zip(input_name_list, input_file_list):
        if src_file is None:
            # No file for this node: keep the data generated above.
            continue
        dst_input_file = "%s/%s" % (
            model_output_dir,
            common.formatted_file_name(input_file_name, input_name))
        if src_file.startswith(("http://", "https://")):
            urlretrieve(src_file, dst_input_file)
        else:
            sh.cp("-f", src_file, dst_input_file)
Example #3
0
def gen_input(model_output_dir,
              input_nodes,
              input_shapes,
              input_files=None,
              input_ranges=None,
              input_data_types=None,
              input_data_map=None,
              input_file_name="model_input"):
    """Write one input file per model input node into ``model_output_dir``.

    Any existing input file for a node in ``input_nodes`` is removed first.
    The new data then comes from exactly one source, in priority order:

    1. ``input_data_map`` (if truthy): values keyed by node name. Each value
       is converted with ``np.array`` to the dtype named in
       ``input_data_types`` ('float32' or 'int32'). It is checked against
       ``input_shapes`` and written with ``ndarray.tofile``.
    2. ``input_files`` (if it holds any entries): each non-None entry is
       downloaded (``http://`` / ``https://``) or copied to the node's file.
    3. Otherwise ``generate_input_data`` produces random input files. This
       path joins ``input_ranges`` and ``input_data_types``, so both must be
       given.

    Args:
        model_output_dir: Directory the per-node input files are written to.
        input_nodes: Input node names.
        input_shapes: Per-node shape strings.
        input_files: Per-node source paths/URLs (list/tuple, or a single
            value). ``None`` (the default) means no files are supplied.
        input_ranges: Per-node value-range strings, used only for random
            generation.
        input_data_types: Per-node data-type strings.
        input_data_map: Optional mapping of node name to input values.
        input_file_name: Base name passed to ``common.formatted_file_name``.

    Consistency failures are reported through ``common.mace_check``
    (presumably raising; its behavior is defined outside this function).
    """
    def _dst_path(input_name):
        # Destination file for one input node inside model_output_dir.
        return "%s/%s" % (model_output_dir,
                          common.formatted_file_name(input_file_name,
                                                     input_name))

    for input_name in input_nodes:
        stale_file = _dst_path(input_name)
        if os.path.exists(stale_file):
            sh.rm(stale_file)

    # The default ``None`` means "no input files". Wrapping it as [None]
    # would make the list non-empty, skip random generation, and leave the
    # freshly removed input files missing.
    if input_files is None:
        input_file_list = []
    elif isinstance(input_files, (list, tuple)):
        input_file_list = list(input_files)
    else:
        input_file_list = [input_files]

    if input_data_map:
        np_dtypes = {'float32': np.float32, 'int32': np.int32}
        for i, input_name in enumerate(input_nodes):
            common.mace_check(
                input_name in input_data_map, common.ModuleName.RUN,
                "The preprocessor API in PrecisionValidator"
                " script should return all inputs of model")
            common.mace_check(
                input_data_types[i] in np_dtypes, common.ModuleName.RUN,
                'Do not support input data type %s' % input_data_types[i])
            input_data = np.array(input_data_map[input_name],
                                  dtype=np_dtypes[input_data_types[i]])
            expected_shape = list(map(int,
                                      common.split_shape(input_shapes[i])))
            common.mace_check(
                expected_shape == list(input_data.shape),
                common.ModuleName.RUN,
                "The shape return from preprocessor API of"
                " PrecisionValidator script is not same with"
                " model deployment file. %s vs %s" %
                (str(input_shapes[i]), str(input_data.shape)))
            input_data.tofile(_dst_path(input_name))
    elif input_file_list:
        if isinstance(input_nodes, (list, tuple)):
            input_name_list = list(input_nodes)
        else:
            input_name_list = [input_nodes]
        common.mace_check(
            len(input_file_list) == len(input_name_list),
            common.ModuleName.RUN,
            'If input_files set, the input files should '
            'match the input names.')
        for input_name, src_file in zip(input_name_list, input_file_list):
            if src_file is None:
                continue
            dst_input_file = _dst_path(input_name)
            if src_file.startswith(("http://", "https://")):
                six.moves.urllib.request.urlretrieve(
                    src_file, dst_input_file)
            else:
                sh.cp("-f", src_file, dst_input_file)
    else:
        # generate random input files
        generate_input_data("%s/%s" % (model_output_dir, input_file_name),
                            ",".join(input_nodes),
                            ":".join(input_shapes),
                            ":".join(input_ranges),
                            ",".join(input_data_types))