def str_dict(self, d_str):
    """Parse a Python dict literal from *d_str*; return the dict, or False on failure."""
    try:
        return ast.literal_eval(d_str)
    except Exception as e:
        log.error("String conversion dictionary faile {}".format(e))
        return False
def create_first_dir(self, dataset_names, ids):
    """Create the first-level directory for every id found in *dataset_names*.

    Ids absent from *dataset_names* are silently skipped. Returns False on the
    first creation failure, True otherwise.
    """
    for dataset_id in ids:
        if dataset_id not in dataset_names:
            continue
        target = config.config["path"]["firstdir_path"].format(
            dataset_names[dataset_id], dataset_id)
        if not self.create_dir(target):
            log.error("file {} create faile".format(dataset_names[dataset_id]))
            return False
    return True
def clear(self, path):
    """Recursively delete the directory tree at *path*; True on success, False on error."""
    try:
        shutil.rmtree(path)
    except Exception as e:
        log.error("del file error:{}".format(e))
        return False
    return True
def create_dir(self, path):
    """Create *path* if it does not exist; True if the directory exists afterwards."""
    if os.path.exists(path):
        return True
    try:
        os.mkdir(path)
        return True
    except Exception as e:
        log.error(e)
        return False
def test_model_info_get(self, test_model_id, session):
    """Return the most recent dataset_test_model row for *test_model_id*.

    Uses the caller-supplied *session*. Returns the row, or False on any
    query error (including no rows, since indexing an empty result raises).
    """
    try:
        rows = session.query(
            Tables["dataset_test_model"]).filter_by(
                testing_model_id=test_model_id).all()
        # Rows come back in insertion order; the last one is the newest.
        return rows[-1]
    except Exception as e:
        # Fixed typo in the original message ("test 0model").
        log.error("test model info get fail {}".format(e))
        return False
def image_save(self, dataset_names, id, response, tag, image):
    """Decode one base64 image from *response* and write it to the dataset's image path.

    Returns True on success, False on any lookup/decode/write failure.
    """
    try:
        target = config.config["path"]["image_path"].format(
            dataset_names[id], id, tag, image)
        # The `with` block closes the file; the original's explicit
        # f.close() inside it was redundant and has been removed.
        with open(target, "wb+") as f:
            f.write(base64.b64decode(response.json()["data"][tag][image]))
    except Exception as e:
        # Fixed typo in the original message ("svase").
        log.error("image:{} save error {}".format(image, e))
        return False
    return True
def dataset_path_get(self):
    """Return absolute paths of the unpacked dataset directories (zip files excluded).

    Side effect: caches the dataset root on self.cur_path. Returns False on error.
    """
    try:
        self.cur_path = os.path.dirname(os.path.realpath(__file__))
        relative_root = config.config["path"]["datasetdir_path"].replace("./", "")
        self.cur_path = os.path.join(self.cur_path, relative_root)
        entries = os.listdir(config.config["path"]["datasetdir_path"])
        return [os.path.join(self.cur_path, entry)
                for entry in entries
                if ".zip" not in entry]
    except Exception as e:
        log.error("dataset path get fail {}".format(e))
        return False
def mid_train_switch(self, task_type, data):
    """Dispatch *data* to the trainer matching *task_type* (0=init, 1=opt, 2=assess).

    Returns the trainer's result, or False for an unknown task type.
    """
    if task_type == 0:
        return model_train.mid_model_init_train(data)
    if task_type == 1:
        return model_train.mid_model_opt_train(data)
    if task_type == 2:
        return model_train.mid_model_ass_train(data)
    log.error("fail task_type")
    return False
def sendjsondata(self, params, key):
    """Serialize *params* to JSON and publish it on self.kafkatopic under *key*.

    Returns True on success, False on a Kafka error.
    """
    try:
        payload = json.dumps(params).encode('utf-8')
        producer = self.producer
        producer.send(self.kafkatopic, value=payload, key=key.encode("utf-8"))
        # NOTE(review): this closes the shared self.producer after one send —
        # confirm a fresh producer is created per call elsewhere.
        producer.close()
        return True
    except KafkaError as e:
        log.error("kafka input data failed {}".format(e))
        return False
def dataset_info_get(self, dataset_ids):
    """Look up each id in dataset_info and map it to its dataset_name.

    Returns {dataset_id: dataset_name}, or False if any lookup fails.
    """
    data = []
    self.session = self.DBSession()
    try:
        for dataset_id in dataset_ids:
            data.append(
                self.session.query(Tables["dataset_info"]).get(dataset_id))
    except Exception as e:
        log.error("The dataset query failed {}".format(e))
        return False
    finally:
        # BUG FIX: the original leaked the session on the error path;
        # close it unconditionally.
        self.session.close()
    return {i: row.dataset_name for i, row in zip(dataset_ids, data)}
def create_model_save_dir(self, path):
    """Recursively create *path* for model artifacts; True unless creation fails."""
    if os.path.exists(path):
        print("Folder already exists")
        return True
    try:
        os.makedirs(path)
    except Exception as e:
        log.error("Model file save file folder creation failed {}".format(e))
        return False
    return True
def dowload_save_bytes(self, dataset_ids):
    """Download every dataset in *dataset_ids* and stream its bytes to disk.

    Returns False as soon as one download fails, True when all succeed.
    """
    for dataset_id in dataset_ids:
        if not self.reponse_get_bytes(dataset_id):
            log.error("dataset:{} dowload fail".format(dataset_id))
            return False
        target = config.config["path"]["datasetdir_paths"].format(dataset_id)
        # NOTE(review): "ab" appends, so re-runs grow stale files — presumably
        # the dataset directory is cleared beforehand; confirm.
        with open(target, "ab") as f:
            for chunk in self.response_bytes.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
    return True
def mid_model_init_train(self, parameter):
    """Build augmented train/valid generators and launch initial model training.

    Mutates *parameter* in place (adds model_fileaddr and the generators) and
    hands it to mid_package_program. GPU-setup errors are logged and ignored.
    Returns True.
    """
    if config.config["gpu_control"]["open"] == "1":
        try:
            self.memory_size = config.config["memory_size"]["mm_size"]
            self.gpu_init(self.memory_size)
        except Exception as e:
            log.error("gpu error {}".format(e))
    # Data-augmentation configuration for the training set.
    self.train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=40,       # random rotation range (degrees)
        width_shift_range=0.2,   # random horizontal shift fraction
        height_shift_range=0.2,  # random vertical shift fraction
        shear_range=0.2,         # random shear intensity
        zoom_range=0.2,          # random zoom range
        horizontal_flip=True,    # enable horizontal flips
        fill_mode='nearest')     # fill strategy for newly created pixels
    self.valid_datagen = ImageDataGenerator(rescale=1. / 255)
    image_height, image_width = self.mid_height_width(parameter)  # e.g. 320x480
    self.model_path = os.path.join(parameter["new_model_path"],
                                   parameter["train_model_name"])
    self.result_path = os.path.join(self.model_path, "model_result")
    parameter.update({"model_fileaddr": self.result_path})
    dir_operator.create_model_save_dir(self.result_path)
    self.train_dataset = ""
    self.vaild_dataset = ""
    # BUG FIX: the original referenced bare `train_datagen`/`valid_datagen`
    # (NameError) — the generators were assigned to self above.
    self.train_generator = self.train_datagen.flow_from_directory(
        self.train_dataset,
        target_size=(image_height, image_width),
        # keras expects a list of class names, not a dict_keys view.
        classes=list(parameter["target"].keys()),
        batch_size=parameter["batch_size"],
        shuffle=True,
        class_mode=config.config["parameter"]["class_mode"])
    self.valid_generator = self.valid_datagen.flow_from_directory(
        self.vaild_dataset,
        target_size=(image_height, image_width),
        classes=list(parameter["target"].keys()),
        batch_size=int(config.config["parameter"]["batch_size"]),
        shuffle=True,
        class_mode=config.config["parameter"]["class_mode"])
    # BUG FIX: the original called `data.update(...)` on an undefined name;
    # mid_package_program reads parameter["train_generator"], so update the
    # parameter dict.
    parameter.update({
        "train_generator": self.train_generator,
        "valid_generator": self.valid_generator
    })
    self.mid_package_program(parameter)
    return True
def updata_train_mdoel_status(self, model_id, status):
    """Set training_status of the dataset_model_info row *model_id* to *status*.

    Returns True on commit, False (after logging) on any failure.
    """
    self.session = self.DBSession()
    try:
        rows = self.session.query(
            Tables["dataset_model_info"]).filter_by(id=model_id).all()
        rows[0].training_status = status
        self.session.commit()
        self.session.close()
        return True
    except Exception as e:
        log.error("model:{} status change failed{}".format(model_id, e))
        self.session.close()
        return False
def mid_merge_dataset(self, path_list):
    """Merge every tag folder of each dataset in *path_list* into the shared dataset dir.

    Files are copied tag-by-tag; same-named images overwrite each other.
    Returns True on success, False (after logging) on any error.
    """
    try:
        merged_root = config.config["path"]["all_dataset_dir_path"]
        self.create_dir(merged_root)
        for dataset_dir in path_list:
            for tag in os.listdir(dataset_dir):
                print(tag)
                self.create_dir(os.path.join(merged_root, tag))
                for image_name in os.listdir(os.path.join(dataset_dir, tag)):
                    print(os.path.join(dataset_dir, tag, image_name))
                    shutil.copyfile(
                        os.path.join(dataset_dir, tag, image_name),
                        os.path.join(merged_root, tag, image_name))
        return True
    except Exception as e:
        log.error("dataset merge error {}".format(e))
        return False
def train_result_save_switch(self, task_type, data, node):
    """Persist results by task type: 0/1 = init/opt training, 2 = assessment.

    On a failed save the model status is set to 4 and False is returned;
    unknown task types return False without side effects.
    """
    # BUG FIX: the original compared ints with `is`, which only works by
    # accident of CPython's small-int caching; use `==`.
    if task_type == 0 or task_type == 1:
        if not dbase.insert_init_opt_result_save(data):
            log.error("init opt result save error {}".format(node))
            self.update_status(data, 4, node, data["train_model_id"])
            return False
        return True
    elif task_type == 2:
        if not dbase.insert_ass_result_save(data):
            log.error("ass result save error {}".format(node))
            self.update_status(data, 4, node, data["train_model_id"])
            return False
        return True
    else:
        return False
def mid_create_model(self, parameter):
    """Build a ResNet50-backbone classifier with a small dense head.

    Returns (True, model) on success, (False, None) on error.
    """
    try:
        image_height, image_width = self.mid_height_width(parameter)
        # [include_top, weights, input_shape] for the reflected ResNet50 ctor.
        resnet_args = [False, None, (image_height, image_width, 3)]
        self.c_conv_base = self.mid_reflect("keras.applications.resnet",
                                            "ResNet50", resnet_args)
        self.x = self.c_conv_base.output
        self.x = Flatten()(self.x)
        self.x = Dense(64, activation="relu")(self.x)
        self.x = Dense(len(parameter["target"]), activation="softmax")(self.x)
        self.c_model = Model(inputs=self.c_conv_base.input, outputs=self.x)
        return True, self.c_model
    except Exception as e:
        log.error("create model error {}".format(e))
        return False, None
def reponse_get_bytes(self, id):
    """Stream-download dataset *id*, keeping the response on self.response_bytes.

    Returns True when the request succeeds with HTTP 200, otherwise False.
    """
    self.headers = {
        'content-type': 'application/json',
        "authorization": "Bearer wYlvTrIhhvMpfBC4cW1Wdef4FoNgJt"
    }
    try:
        self.response_bytes = requests.get(
            config.config["url"]["dataset_url"].format(id),
            headers=self.headers,
            stream=True)
        # BUG FIX: the original tested `> 200 and < 200`, which is always
        # False, so HTTP error responses were silently treated as success.
        if self.response_bytes.status_code != 200:
            log.error("dataset_{} response faile".format(id))
            return False
        return True
    except Exception as e:
        log.error("requests error %s" % e)
        return False
def model_info_get(self, model_id, model_version):
    """Fetch the current model row plus version information.

    Returns a 3-tuple (now_model_info, requested_version_or_None,
    latest_version_or_None), or False on a query error.
    """
    self.session = self.DBSession()
    now_mdoel_info = self.session.query(
        Tables["dataset_model_info"]).get(model_id)
    # BUG FIX: the original compared ints with `is`; use `==`.
    if now_mdoel_info.model_status == 1:
        last_version = None
        if model_version is None:
            try:
                versions = self.session.query(
                    Tables["dataset_train_model"]).filter_by(
                        training_model_id=model_id).order_by(
                            Tables["dataset_train_model"].id).all()
                last_version = versions[-1] if versions else None
                n_version_info = None
                self.session.close()
                return now_mdoel_info, n_version_info, last_version
            except Exception as e:
                self.session.close()
                log.error("{} model:{} info get error".format(e, model_id))
                return False
        # NOTE(review): as in the original, model_status == 1 with an explicit
        # model_version falls through and implicitly returns None (leaving the
        # session open) — confirm whether that combination can occur.
    else:
        try:
            versions = self.session.query(
                Tables["dataset_train_model"]).filter_by(
                    training_model_id=model_id).order_by(
                        Tables["dataset_train_model"].id).all()
            n_version_info = self.session.query(
                Tables["dataset_train_model"]).filter_by(
                    training_model_id=model_id,
                    model_version=model_version).all()
            self.session.close()
            # (name, address, last_version)
            return now_mdoel_info, n_version_info[0], versions[-1]
        except Exception as e:
            log.error("{} model:{} info get error".format(e, model_id))
            self.session.close()
            return False
def un_zip(self, zip_file_name):
    """Extract *zip_file_name* from the dataset dir into that same dir.

    Entries without a '.' in the path are treated as directories (heuristic
    kept from the original — fragile for dotted folder names). Returns True
    on success, False once any entry fails.
    """
    self.cur_path = os.path.dirname(os.path.realpath(__file__))
    d_path = config.config["path"]["datasetdir_path"].replace("./", "")
    self.zip_path = os.path.join(self.cur_path, d_path, zip_file_name)
    self.tar_path = os.path.join(self.cur_path, d_path)
    # BUG FIX: use a context manager so the archive handle is closed even on
    # the error return path (the original leaked it when returning False).
    with zipfile.ZipFile(self.zip_path, "r") as z:
        for entry in z.infolist():
            try:
                save_path = (self.tar_path + "/" + entry.filename).replace("\\", "/")
                if not os.path.exists(save_path):
                    if "." not in save_path:
                        # BUG FIX: makedirs handles nested directories whose
                        # parents are missing from the archive listing;
                        # os.mkdir failed on those.
                        os.makedirs(save_path)
                if "." in save_path:
                    with open(save_path, "wb") as f:
                        f.write(z.read(entry))
            except Exception as e:
                log.error("image unzip fail {}".format(e))
                return False
    return True
def insert_init_opt_result_save(self, data):
    """Insert a dataset_train_model row plus its per-label index rows.

    Returns True on success, False (after logging) on any failure.
    """
    # Keep a local alias: self.model_info_get() below reassigns (and closes)
    # self.session, which in the original left the in-flight transaction on a
    # closed session.
    session = self.DBSession()
    self.session = session
    insert_data = Tables["dataset_train_model"](
        training_time=data["training_time"],
        accuracy_rate=data["accuracy_rate"],
        model_version=data["model_version"],
        model_fileaddr=data["model_fileaddr"],
        dataset_list=data["dataset_list"],
        training_model_id=data["training_model_id"])
    try:
        session.add(insert_data)
        session.commit()
        try:
            n_traininginfo_id = self.model_info_get(
                data["training_model_id"], None)[2].id
        except Exception as e:
            log.error("model info get error {}".format(e))
            # BUG FIX: the original returned here without closing the session.
            session.close()
            return False
        for da in data["tags_index"]:
            insert_da = Tables["dataset_train_index"](
                F1score=da["F1score"],
                Gscore=da["Gscore"],
                precision_rate=da["precision_rate"],
                recall_rate=da["recall_rate"],
                label_id=da["label_id"],
                traininginfo_id=n_traininginfo_id)
            session.add(insert_da)
            session.commit()
        session.close()
        return True
    except Exception as e:
        log.error("data insert error {}".format(e))
        session.close()
        return False
def mid_package_program(self, parameter):
    """Compile and fit the model described by *parameter*, then save it.

    Expects parameter to carry "train_generator"/"valid_generator" (keras
    directory iterators), reflection targets for loss/optimizer, and path
    fields. Returns (True, parameter) on success, (False, parameter) on error.
    """
    try:
        image_height, image_width = self.mid_height_width(parameter)
        self.flag, self.m_model = self.mid_create_model(
            parameter)  # TODO: to be revised later (kept from original note)
        print("-" * 10)
        # mid_reflect returns the loss class; the trailing () instantiates it.
        self.m_loss = self.mid_reflect(parameter["loss_value"],
                                       "SparseCategoricalCrossentropy", [])()
        # The optimizer is reflected with its learning rate as ctor argument.
        self.m_optimizer = self.mid_reflect(
            parameter["optimizer"], "RMSprop",
            [float(config.config["parameter"]["optimizer_lr"])])
        self.m_model.compile(loss=self.m_loss,
                             optimizer=self.m_optimizer,
                             metrics=['accuracy'])
        # Steps per epoch = sample count // batch size for each generator.
        self.train_num = parameter["train_generator"].samples // int(
            config.config["parameter"]["batch_size"])
        self.valid_num = parameter["valid_generator"].samples // int(
            config.config["parameter"]["batch_size"])
        self.model_path = os.path.join(parameter["new_model_path"],
                                       parameter["train_model_name"])
        self.result_path = os.path.join(self.model_path, "model_result")
        self.m_callbacks = self.mid_model_callback(self.model_path,
                                                   self.result_path)
        self.m_model.fit_generator(
            parameter["train_generator"],
            steps_per_epoch=self.train_num,
            shuffle=True,
            epochs=int(parameter["iterate_times"]),
            validation_data=parameter["valid_generator"],
            validation_steps=self.valid_num,
            callbacks=self.m_callbacks)
        # Save the full model alongside the per-epoch results.
        self.m_model.save("{}/all_{}.h5".format(
            parameter["model_fileaddr"], parameter["train_model_name"]))
        return True, parameter
    except Exception as e:
        log.error("training error {}".format(e))
        return False, parameter
def insert_ass_result_save(self, data):
    """Insert a test-model row, its per-label index rows and error images.

    Returns True on success, False (after logging) at the first failing step.
    """
    # BUG FIX: model_info_get can return False, which made the original's
    # unguarded `[1].id` raise an uncaught TypeError.
    try:
        n_traininginfo_id = self.model_info_get(
            data["training_model_id"], data["model_version"])[1].id
    except Exception as e:
        log.error("model info get error {}".format(e))
        return False
    self.session = self.DBSession()
    insert_data = Tables["dataset_test_model"](
        test_time=data["training_time"],
        accuracy_rate=data["accuracy_rate"],
        dataset_list=data["dataset_list"],
        testing_model_id=n_traininginfo_id)
    try:
        self.session.add(insert_data)
        self.session.commit()
    except Exception as e:
        log.error("test data insertion failed {}".format(e))
        self.session.close()
        return False
    try:
        n_test_mode_id = self.test_model_info_get(n_traininginfo_id,
                                                  self.session).id
        for da in data["tags_index"]:
            insert_da = Tables["dataset_test_index"](
                F1score=da["F1score"],
                Gscore=da["Gscore"],
                precision_rate=da["precision_rate"],
                recall_rate=da["recall_rate"],
                label_id=da["label_id"],
                testinfo_id=n_test_mode_id)
            self.session.add(insert_da)
            self.session.commit()
    except Exception as e:
        log.error("test data index insertion failed {}".format(e))
        self.session.close()
        return False
    try:
        for da_img in data["error_images"]:
            insert_da_img = Tables["dataset_test_err"](
                pic_address=da_img["pic_address"],
                old_label=da_img["old_label"],
                new_label=da_img["new_label"],
                testinfo_id=n_test_mode_id)
            self.session.add(insert_da_img)
            self.session.commit()
    except Exception as e:
        log.error("error picture recording failed {}".format(e))
        self.session.close()
        return False
    self.session.close()
    return True
def train_parameter(self, data):
    """Download/unzip datasets and enrich *data* with paths and version info.

    Task types: 0 = initial training, 1 = optimization, 2 = assessment.
    Returns the enriched parameter dict, or False on any preparation failure.
    """
    self.parameter = data
    # Dataset download.
    if not d_dataset.dowload_save_bytes(
            self.parameter["train_dataset_id"]):
        log.error("Data set download failed")
        return False
    if not dir_operator.unzip_all():
        log.error("File unzip failed")
        return False
    # dataset_path = dir_operator.dataset_path_get()
    dataset_path = config.config["path"]["all_datasetfiles"]
    if not dataset_path:
        log.error("File address acquisition failed")
        return False
    self.parameter.update({"dataset_path": dataset_path})
    # BUG FIX: the original compared ints with `is`; use `==` throughout.
    if self.parameter["task_type"] == 0:
        self.parameter_model_version = None
        self.parameter.update(
            {"model_version": self.parameter_model_version})
        n_model_file_path = dbase.model_info_get(
            self.parameter["train_model_id"],
            self.parameter["model_version"])
        if n_model_file_path[2] is None:
            self.parameter_model_version = None
            self.parameter["model_version"] = self.parameter_model_version
        self.parameter["train_model_name"] = (
            self.parameter["train_model_name"] + "_{}".format(
                cal_result.version_get(self.parameter["train_model_id"],
                                       self.parameter["model_version"])))
        dir_operator.create_model_save_dir("/{}/{}".format(
            os.path.split(self.parameter["new_model_path"])[1],
            self.parameter["train_model_name"]))
    elif self.parameter["task_type"] == 1:
        n_model_file_path = dbase.model_info_get(
            self.parameter["train_model_id"],
            self.parameter["model_version"])
        if n_model_file_path[2] is None:
            log.error("Models cannot be optimized without training")
            return False
        self.parameter.update(
            {"model_file_path": n_model_file_path[1].model_fileaddr})
        self.parameter["train_model_name"] = self.parameter[
            "train_model_name"] + "_{}".format(
                cal_result.version_get(self.parameter["train_model_id"],
                                       self.parameter["model_version"]))
        dir_operator.create_model_save_dir("/{}/{}".format(
            os.path.split(self.parameter["new_model_path"])[1],
            self.parameter["train_model_name"]))
    elif self.parameter["task_type"] == 2:
        n_model_file_path = dbase.model_info_get(
            self.parameter["train_model_id"],
            self.parameter["model_version"])
        if n_model_file_path[2] is None:
            log.error("Models cannot be evaluated without training")
            return False
        self.parameter.update(
            {"model_file_path": n_model_file_path[1].model_fileaddr})
    else:
        log.error("Wrong type of training")
        return False
    return self.parameter
def task_allocation_kazoo(self):
    """Poll zookeeper task nodes, run the matching training task, record status.

    For each user node: parse the task payload, skip tasks in a terminal
    status, prepare parameters, run the trainer and persist results, then
    reset the dataset directories. Errors on a node are logged and the next
    node is processed.
    """
    kazoo_client.zk_client.start()
    for node in kazoo_client.zk_client.get_children(
            config.config["kazoo"]["KAZOO_ROOT"]):
        self.user_node = kazoo_client.zk_client.get_children(
            config.config["kazoo"]["USER_NODE"].format(node))
        for nn in self.user_node:
            try:
                raw = str(
                    kazoo_client.zk_client.get(
                        config.config["kazoo"]["KAZOO_NODE"].format(
                            node, nn))[0],
                    encoding="utf-8").replace("null", "None")
                d_data = ast.literal_eval(raw)
                # BUG FIX: the original compared ints with `is`; use `==`
                # and a membership test.
                if len(d_data) == 0:
                    break
                d_data["task_type"] = int(d_data["task_type"])
                if d_data["training_status"] in (4, 5, 8, 9, 12, 13):
                    print("task is trained")
                    continue
                print("-" * 10)
                print("task type {}".format(d_data["task_type"]))
                # Stamp the task with its training start time.
                d_data.update({
                    "training_time":
                    time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                })
                self.t_data = self.train_parameter(d_data)
                # NOTE(review): update_status is called twice with identical
                # arguments, exactly as in the original — confirm whether the
                # duplicate call is intentional.
                self.t_data = self.update_status(
                    self.t_data, 4, self.user_node, nn,
                    self.t_data["train_model_id"])
                self.t_data = self.update_status(
                    self.t_data, 4, self.user_node, nn,
                    self.t_data["train_model_id"])
                # Start training.
                print("running task {}".format(node))
                t_f = self.mid_train_switch(self.t_data["task_type"],
                                            self.t_data)
                if not t_f:
                    self.update_status(self.t_data, 4, self.user_node, nn,
                                       self.t_data["train_model_id"])
                    log.error("Abnormal end of training id {}".format(node))
                    print("error task {}".format(node))
                    break
                self.res = cal_result.get_result(self.t_data,
                                                 self.t_data["task_type"])
                self.train_result_save_switch(self.t_data["task_type"],
                                              self.res, node)
                print("end task true{}".format(node))
                # Reset the dataset working directories for the next task.
                dir_operator.clear(config.config["path"]["datasetdir_path"])
                dir_operator.clear(
                    config.config["path"]["all_dataset_dir_path"])
                dir_operator.create_dir(
                    config.config["path"]["datasetdir_path"])
                dir_operator.create_dir(
                    config.config["path"]["all_dataset_dir_path"])
                self.update_status(self.t_data, 4, self.user_node, nn,
                                   self.t_data["train_model_id"])
                print("running end", True)
                break
            except Exception as e:
                log.error(e)
                continue