def normalize_model_config(conf, model_output=None, org_model_dir=None):
    """Normalize a model config and expand its subgraphs into full configs.

    The model-level config and every subgraph are passed through
    ``normalize_graph_config`` with the same ``model_output`` and
    ``org_model_dir``.

    ``ModelKeys.subgraphs`` may be either:

    * a list -- only its first element is used.  It is stored under
      ``ModelKeys.default_graph``, and its normalized input/output tensors are
      copied to the model level; or
    * a mapping of graph name -> subgraph config -- every entry is normalized
      and kept under its own name.

    ``set_default_config_value(subgraph, conf)`` is called on each normalized
    subgraph (presumably to fill gaps from the model-level config -- TODO
    confirm against its definition).  Finally, each subgraph entry is replaced
    by a deep copy of the model-level config (without ``subgraphs``) updated
    with that subgraph's own keys.  Subgraph values therefore win over
    model-level ones.

    Args:
        conf: raw model configuration dict.
        model_output: forwarded unchanged to ``normalize_graph_config``.
        org_model_dir: forwarded unchanged to ``normalize_graph_config``.

    Returns:
        The object returned by ``normalize_graph_config`` for ``conf``, with
        ``subgraphs`` rewritten as described above.  It is also logged via
        ``MaceLogger.summary``.
    """
    def _normalize(graph_conf):
        return normalize_graph_config(graph_conf, model_output, org_model_dir)

    conf = _normalize(conf)

    if ModelKeys.subgraphs in conf:
        raw_subgraphs = conf[ModelKeys.subgraphs]
        normalized = {}

        if isinstance(raw_subgraphs, list):
            # List form: only the first entry is consulted.  Its I/O tensors
            # are promoted before defaults are applied, matching the order
            # the model-level conf is expected to be in.
            first = _normalize(raw_subgraphs[0])
            conf[ModelKeys.input_tensors] = first[ModelKeys.input_tensors]
            conf[ModelKeys.output_tensors] = first[ModelKeys.output_tensors]
            set_default_config_value(first, conf)
            normalized[ModelKeys.default_graph] = first
        else:
            for graph_name, graph_conf in raw_subgraphs.items():
                graph = _normalize(graph_conf)
                set_default_config_value(graph, conf)
                normalized[graph_name] = graph

        conf[ModelKeys.subgraphs] = normalized

        # Snapshot of the model-level settings, minus the subgraph table.
        shared = copy.deepcopy(conf)
        shared.pop(ModelKeys.subgraphs)

        # Replace each entry in place (same keys, so iterating is safe).
        for graph_name in normalized:
            net_conf = copy.deepcopy(shared)
            net_conf.update(normalized[graph_name])
            normalized[graph_name] = net_conf

    MaceLogger.summary(conf)
    return conf
def normalize_model_config(conf):
    """Validate and normalize a single-graph model config.

    Works on a deep copy, so the caller's ``conf`` is not modified.  If
    ``ModelKeys.subgraphs`` is present, only its first entry is used.  That
    entry's keys are merged into the top level (overriding same-named keys),
    and the ``subgraphs`` key is removed.

    Normalization performed:

    * ``platform`` and ``runtime`` are parsed with ``parse_platform`` and
      ``parse_device_type``.
    * ``data_type`` becomes ``mace_pb2.DT_FLOAT`` when ``quantize == 1``.
      Otherwise it is parsed with ``parse_internal_data_type`` if given, and
      defaults to ``mace_pb2.DT_HALF`` when absent.
    * Input and output tensor names are coerced to ``str``.  Their shapes are
      parsed with ``parse_int_array``, and the shape count must equal the
      tensor count.
    * Per-tensor attributes are parsed: input data types, formats and ranges,
      and output data types and formats.  A single value is broadcast to
      every tensor; otherwise exactly one value per tensor is required.
      Defaults are ``"float32"``, ``"NHWC"`` and ``"-1.0,1.0"``.
    * If ``check_tensors`` is given, ``check_shapes`` must also be present
      and have the same count.

    Args:
        conf: model configuration dict keyed by ``ModelKeys`` members.

    Returns:
        The normalized copy of ``conf``, which is also logged via
        ``MaceLogger.summary``.

    Raises:
        Whatever ``mace_check`` raises when a count check fails (its
        behavior is defined elsewhere).  A missing required key such as
        ``platform``, ``runtime``, ``input_tensors`` or ``output_shapes``
        raises a lookup error (``KeyError`` for a plain dict).
    """
    conf = copy.deepcopy(conf)
    if ModelKeys.subgraphs in conf:
        first_subgraph = conf[ModelKeys.subgraphs][0]
        del conf[ModelKeys.subgraphs]
        conf.update(first_subgraph)

    conf[ModelKeys.platform] = parse_platform(conf[ModelKeys.platform])
    conf[ModelKeys.runtime] = parse_device_type(conf[ModelKeys.runtime])

    # quantize == 1 forces DT_FLOAT, ignoring any data_type that was given.
    if ModelKeys.quantize in conf and conf[ModelKeys.quantize] == 1:
        conf[ModelKeys.data_type] = mace_pb2.DT_FLOAT
    elif ModelKeys.data_type in conf:
        conf[ModelKeys.data_type] = parse_internal_data_type(
            conf[ModelKeys.data_type])
    else:
        conf[ModelKeys.data_type] = mace_pb2.DT_HALF

    def _parse_names(key):
        """Return ``conf[key]`` as a list of ``str`` tensor names."""
        return [str(name) for name in to_list(conf[key])]

    def _parse_shapes(key, count, message):
        """Parse ``conf[key]`` into int arrays and require ``count`` of them."""
        conf[key] = [parse_int_array(shape) for shape in to_list(conf[key])]
        mace_check(len(conf[key]) == count, message)

    def _parse_per_tensor(key, default, parse, count, message):
        """Parse a per-tensor attribute, broadcasting a single value."""
        values = [parse(value) for value in to_list(conf.get(key, default))]
        if len(values) == 1 and count > 1:
            values = [values[0]] * count
        mace_check(len(values) == count, message)
        conf[key] = values

    # Inputs.
    conf[ModelKeys.input_tensors] = _parse_names(ModelKeys.input_tensors)
    input_count = len(conf[ModelKeys.input_tensors])
    _parse_shapes(ModelKeys.input_shapes, input_count,
                  "input node count and shape count do not match")
    _parse_per_tensor(ModelKeys.input_data_types, ["float32"],
                      parse_data_type, input_count,
                      "the number of input_data_types should be "
                      "the same as input tensors")
    _parse_per_tensor(ModelKeys.input_data_formats, ["NHWC"],
                      parse_data_format, input_count,
                      "the number of input_data_formats should be "
                      "the same as input tensors")
    _parse_per_tensor(ModelKeys.input_ranges, ["-1.0,1.0"],
                      parse_float_array, input_count,
                      "the number of input_ranges should be "
                      "the same as input tensors")

    # Outputs.
    conf[ModelKeys.output_tensors] = _parse_names(ModelKeys.output_tensors)
    output_count = len(conf[ModelKeys.output_tensors])
    # Same rule as for inputs: exactly one parsed shape per output tensor.
    _parse_shapes(ModelKeys.output_shapes, output_count,
                  "output node count and shape count do not match")
    _parse_per_tensor(ModelKeys.output_data_types, ["float32"],
                      parse_data_type, output_count,
                      "the number of output_data_types should be "
                      "the same as output tensors")
    _parse_per_tensor(ModelKeys.output_data_formats, ["NHWC"],
                      parse_data_format, output_count,
                      "the number of output_data_formats should be "
                      "the same as output tensors")

    # Optional tensors to check; when given, check_shapes is required too.
    if ModelKeys.check_tensors in conf:
        conf[ModelKeys.check_tensors] = to_list(conf[ModelKeys.check_tensors])
        _parse_shapes(ModelKeys.check_shapes,
                      len(conf[ModelKeys.check_tensors]),
                      "check tensors count and shape count do not match.")

    MaceLogger.summary(conf)
    return conf