def run_model(tfrecord_output, original_dir, input_para, network_setting):
    """Train one model on every tfrecord dataset found under `tfrecord_output`.

    Uses main() from `train_image_classifier.py` for the actual training.
    One training run is launched per sub-directory of `tfrecord_output`
    (one per data split); checkpoints are written under
    input_para['train_dir']/network_setting['model_name']/<split>.

    Args:
        tfrecord_output: directory containing the tfrecord-format training
            data, one sub-directory per split.
        original_dir: directory with the original (non-tfrecord) data; used
            only to count the training images of each split.
        input_para: dict of training parameters; not mutated by this call.
        network_setting: dict describing one model (must contain
            'model_name'); its entries override those of input_para.

    Returns:
        None
    """
    # A shallow copy is not enough: 'split_to_size' is a nested dict that is
    # mutated below, so copy it as well to avoid clobbering the caller's dict.
    input_para = input_para.copy()
    input_para['split_to_size'] = dict(input_para['split_to_size'])
    # The model-specific settings take precedence over the generic ones.
    input_para.update(network_setting)

    tfrecord_files = os.listdir(tfrecord_output)
    tmp_train_dir = os.path.join(input_para['train_dir'],
                                 network_setting['model_name'])
    api.mkdirs(tmp_train_dir)
    for s in tfrecord_files:
        print("[INFO] Use model %s, training on data %s" %
              (network_setting['model_name'], s))
        # Number of training images for this split, counted from the
        # original-format data (the tfrecord files hide the count).
        tmp_data_original = os.path.join(original_dir, str(s), 'train')
        train_size = len(api.get_files(tmp_data_original))
        tmp_dataset_dir = os.path.join(tfrecord_output, str(s))
        tmp_class_train_dir = os.path.join(tmp_train_dir, str(s))
        api.mkdirs([tmp_dataset_dir, tmp_class_train_dir])
        input_para['dataset_dir'] = tmp_dataset_dir
        input_para['train_dir'] = tmp_class_train_dir
        input_para['split_to_size']['train'] = train_size
        train_main(input_para)
def data_split(input_dir, output_dir, label_path, k_fold=5):
    """Split a dataset into k folds, each with train/ and test/ sub-folders.

    Args:
        input_dir: path of the dataset to split.
        output_dir: destination root; one numbered folder per fold is
            created, each containing a 'train' and a 'test' directory.
        label_path: label file for the dataset, one line per image:
                pic.jpeg class_name
        k_fold: number of folds to generate.

    Returns:
        None.  Layout produced under output_dir:
            0/train/...  0/test/...
            1/train/...  1/test/...
            ...
            <k_fold-1>/train/...  <k_fold-1>/test/...
    """
    print("split data...")
    api.mkdirs(output_dir)
    # Map image file name -> class name, then let disposal_data do the split.
    labels = disposal_data.getLabelsDict(label_path)
    disposal_data.disposal(input_dir, output_dir, labels, k_fold=k_fold)
def convert_model(train_dir, test_data_dir, tfrecord_output, network_setting,
                  model_save_para, input_para):
    """Convert trained checkpoints into frozen graph files for prediction.

    Uses main() from `export_inference_graph.py` to define the input/output
    interface and main() from `freeze_graph.py` to write the final graph.
    Output goes to model_save_para['graph_dir']/network_setting['model_name'].

    Args:
        train_dir: root directory holding the trained model checkpoints.
        test_data_dir: path of the original-format data to predict.
        tfrecord_output: path of the tfrecord-format data to predict.
        network_setting: dict describing the model.
        model_save_para: dict of export/freeze parameters; not mutated.
        input_para: dict of training parameters (dataset name, offsets, ...).

    Returns:
        None
    """
    # A shallow copy is not enough: 'split_to_size' is a nested dict that is
    # mutated below, so copy it as well to avoid clobbering the caller's dict.
    model_save_para = model_save_para.copy()
    model_save_para['split_to_size'] = dict(model_save_para['split_to_size'])
    model_save_para['model_name'] = network_setting['model_name']
    model_save_para['default_image_size'] = network_setting['train_image_size']
    model_save_para['dataset_name'] = input_para['dataset_name']
    model_save_para['labels_offset'] = input_para['labels_offset']
    model_save_para['data_split'] = input_para['test_split_name']
    tmp_train_dir = os.path.join(train_dir, network_setting['model_name'])
    model_index = os.listdir(tmp_train_dir)
    tmp_test_data_dir = os.path.join(test_data_dir,
                                     network_setting['model_name'])
    graph_dir = os.path.join(model_save_para['graph_dir'],
                             network_setting['model_name'])
    api.mkdirs(graph_dir)
    for s in model_index:
        # Size of the test split for this fold, counted from the
        # original-format data.
        tmp_test_class_dir = os.path.join(tmp_test_data_dir, str(s), 'test')
        tmp_test_data_size = len(api.get_files(tmp_test_class_dir))
        model_save_para['split_to_size']['test'] = tmp_test_data_size
        tmp_graph_dir = os.path.join(graph_dir, str(s))
        api.mkdirs(tmp_graph_dir)
        # Export the inference graph definition (no weights yet).
        model_save_para['graph_dir'] = os.path.join(tmp_graph_dir,
                                                    'inf_graph.pb')
        model_save_para['dataset_dir'] = os.path.join(tfrecord_output, str(s))
        export_graph_main(model_save_para)
        # Freeze: merge the exported graph with the latest checkpoint weights
        # into a single frozen_graph.pb.
        tmp_model_ckpt = os.path.join(tmp_train_dir, str(s))
        model_save_para['input_checkpoint'] = api.get_checkpoint(
            tmp_model_ckpt)
        model_save_para['frozen_graph'] = os.path.join(tmp_graph_dir,
                                                       'frozen_graph.pb')
        model_save_para['output_node_names'] = network_setting[
            'output_tensor_name']
        freeze_graph_main(model_save_para)
def prediction_train_data(graph_dir, test_dir, label_path, prediction_para,
                          network_setting):
    """Run prediction over every class split found under `test_dir`.

    Per-class results are saved under
    prediction_para['prediction_output']/network_setting['model_name']/<class>
    as 'prediction.npy'.

    Args:
        graph_dir: root directory holding all frozen graphs.
        test_dir: path of the data to predict.
        label_path: label file of the data to predict.
        prediction_para: dict of prediction parameters; not mutated.
        network_setting: dict describing the model in use.

    Returns:
        dict merging the per-class results of run_inference_on_image.
    """
    params = prediction_para.copy()
    params['label_path'] = label_path
    model_dir = os.path.join(graph_dir, network_setting['model_name'])
    out_root = os.path.join(prediction_para['prediction_output'],
                            network_setting['model_name'])
    merged = {}
    for cls in os.listdir(model_dir):
        params['image_file'] = os.path.join(test_dir, cls, 'test')
        params['model_path'] = os.path.join(model_dir, cls, 'frozen_graph.pb')
        params['tensor_name'] = network_setting['output_tensor_name'] + ":0"
        params['width'] = network_setting['train_image_size']
        params['height'] = network_setting['train_image_size']
        out_dir = os.path.join(out_root, cls)
        api.mkdirs(out_dir)  # make sure the per-class output folder exists
        params['prediction_output'] = os.path.join(out_dir, 'prediction.npy')
        merged.update(run_inference_on_image(params))
    return merged
def data_convert_to_tfrecord(train_dir, tfrecord_output, input_para):
    """Convert each split of the dataset into TFRecord files.

    For every sub-directory of `train_dir` (one per split) a labels.txt is
    generated and the train/test data are sharded under
    tfrecord_output/<split>, named
        <dataset_name>_<split_name>_<i>_of_<tf_num_shards>.tfrecord

    Args:
        train_dir: root of the data to convert; one sub-directory per split,
            each containing a 'train' and a 'test' folder.
        tfrecord_output: output root; mirrors the split layout of train_dir.
        input_para: dict with conversion parameters (split names, shard and
            thread counts, dataset name, class-label base).

    Returns:
        None
    """
    for split in os.listdir(train_dir):
        src = os.path.join(train_dir, str(split))
        dst = os.path.join(tfrecord_output, str(split))
        api.mkdirs(dst)
        train_src = os.path.join(src, 'train')
        test_src = os.path.join(src, 'test')
        labels_file = os.path.join(dst, 'labels.txt')
        # labels.txt is derived from the class folders of the training data.
        data_convert.get_class_labels(train_src, labels_file)
        # Shard the train split first, then the test split, with identical
        # parameters apart from the split name and source directory.
        for split_name, data_dir in (
                (input_para['train_split_name'], train_src),
                (input_para['test_split_name'], test_src)):
            tfrecord.process_dataset(split_name, data_dir,
                                     input_para['tf_num_shards'], labels_file,
                                     input_para['tf_num_threads'], dst,
                                     input_para['dataset_name'],
                                     input_para['tf_class_label_base'])