示例#1
0
 def str_dict(self, d_str):
     """Parse a Python-literal string (e.g. "{'a': 1}") into the object it denotes.

     Returns the parsed object on success, or False on any parse error.
     """
     try:
         # literal_eval only evaluates literals, so arbitrary code cannot run.
         return ast.literal_eval(d_str)
     except Exception as e:
         log.error("String conversion dictionary failed {}".format(e))
         return False
示例#2
0
 def create_first_dir(self, dataset_names, ids):
     """Create the first-level directory for every id present in dataset_names.

     Stops and returns False at the first directory that cannot be created;
     returns True when all succeed (ids missing from dataset_names are skipped).
     """
     template = config.config["path"]["firstdir_path"]
     for dataset_id in ids:
         if dataset_id not in dataset_names:
             continue
         target = template.format(dataset_names[dataset_id], dataset_id)
         if not self.create_dir(target):
             log.error("file {} create faile".format(dataset_names[dataset_id]))
             return False
     return True
示例#3
0
 def clear(self, path):
     """Recursively delete the directory tree at *path*.

     Returns True on success, False (logged) on any failure.
     """
     try:
         shutil.rmtree(path)
     except Exception as e:
         log.error("del file error:{}".format(e))
         return False
     return True
示例#4
0
 def create_dir(self, path):
     """Create directory *path* if it does not already exist.

     Returns True when the directory exists afterwards (created now or
     already present), False when creation failed for any other reason.
     """
     try:
         # EAFP: attempting the mkdir directly avoids the exists()/mkdir()
         # race where another process creates the path in between.
         os.mkdir(path)
         return True
     except FileExistsError:
         # Matches the original behaviour of treating an existing path as success.
         return True
     except Exception as e:
         log.error(e)
         return False
 def test_model_info_get(self, test_model_id, session):
     """Return the most recent dataset_test_model row for *test_model_id*.

     Returns False when the query fails or when no row exists (the
     IndexError from an empty result is caught by the broad except).
     """
     try:
         test_model_info = session.query(
             Tables["dataset_test_model"]).filter_by(
                 testing_model_id=test_model_id).all()
         # The newest record is the last one returned by the query.
         return test_model_info[-1]
     except Exception as e:
         log.error("test model info get fail {}".format(e))
         return False
示例#6
0
 def image_save(self, dataset_names, id, response, tag, image):
     """Decode one base64 image from *response* and write it to the dataset path.

     The target path is built from the dataset name/id plus *tag* and *image*.
     Returns True on success, False (logged) on any failure.
     """
     try:
         target = config.config["path"]["image_path"].format(
             dataset_names[id], id, tag, image)
         # The with-statement closes the file; the original's explicit
         # f.close() inside the with-block was redundant.
         with open(target, "wb+") as f:
             f.write(base64.b64decode(response.json()["data"][tag][image]))
     except Exception as e:
         log.error("image:{} save error  {}".format(image, e))
         return False
     return True
示例#7
0
 def dataset_path_get(self):
     """Return absolute paths of every non-zip entry in the dataset directory.

     Also caches the absolute dataset directory on self.cur_path.
     Returns False (logged) on any failure.
     """
     try:
         self.cur_path = os.path.dirname(os.path.realpath(__file__))
         relative = config.config["path"]["datasetdir_path"].replace("./", "")
         self.cur_path = os.path.join(self.cur_path, relative)
         entries = os.listdir(config.config["path"]["datasetdir_path"])
         return [os.path.join(self.cur_path, name)
                 for name in entries
                 if ".zip" not in name]
     except Exception as e:
         log.error("dataset path get fail {}".format(e))
         return False
示例#8
0
 def mid_train_switch(self, task_type, data):
     """Dispatch *data* to the training routine matching *task_type*.

     0 -> initial training, 1 -> optimisation, 2 -> assessment.
     Returns the routine's result, or False for an unknown task type.
     """
     if task_type == 0:
         return model_train.mid_model_init_train(data)
     if task_type == 1:
         return model_train.mid_model_opt_train(data)
     if task_type == 2:
         return model_train.mid_model_ass_train(data)
     log.error("fail task_type")
     return False
示例#9
0
 def sendjsondata(self, params, key):
     """Serialise *params* to JSON and publish it to self.kafkatopic.

     The message key is *key* (UTF-8 encoded). The producer is closed after
     the send. Returns True on success, False (logged) on a KafkaError.
     """
     try:
         producer = self.producer
         payload = json.dumps(params).encode('utf-8')
         producer.send(self.kafkatopic, value=payload, key=key.encode("utf-8"))
         producer.close()
         return True
     except KafkaError as e:
         log.error("kafka input data failed {}".format(e))
         return False
 def dataset_info_get(self, dataset_ids):
     """Look up the dataset_info row for every id in *dataset_ids*.

     Returns {dataset_id: dataset_name}, or False when any query fails.
     The DB session is closed on every path (the original leaked it when a
     query failed).
     """
     data = []
     self.session = self.DBSession()
     try:
         for dataset_id in dataset_ids:
             try:
                 data.append(self.session.query(
                     Tables["dataset_info"]).get(dataset_id))
             except Exception as e:
                 log.error("The dataset query failed {}".format(e))
                 return False
     finally:
         self.session.close()
     return {d_id: row.dataset_name for d_id, row in zip(dataset_ids, data)}
示例#11
0
    def create_model_save_dir(self, path):
        """Create *path* (including intermediate directories) for model output.

        Returns True when the folder exists afterwards, False (logged) when
        creation failed.
        """
        if os.path.exists(path):
            print("Folder already exists")
            return True
        try:
            os.makedirs(path)
        except Exception as e:
            log.error("Model file save file folder creation failed {}".format(e))
            return False
        return True
示例#12
0
 def dowload_save_bytes(self, dataset_ids):
     """Download every dataset in *dataset_ids* and append its bytes to disk.

     The HTTP response is streamed in 1 KiB chunks. Returns False (logged)
     as soon as one download fails, True once all have been written.
     """
     for dataset_id in dataset_ids:
         if not self.reponse_get_bytes(dataset_id):
             log.error("dataset:{} dowload fail".format(dataset_id))
             return False
         target = config.config["path"]["datasetdir_paths"].format(dataset_id)
         with open(target, "ab") as f:
             for chunk in self.response_bytes.iter_content(chunk_size=1024):
                 if chunk:
                     f.write(chunk)
     return True
    def mid_model_init_train(self, parameter):
        """Run initial model training driven by the *parameter* dict.

        Builds augmented train/validation generators, stores them back into
        *parameter* (mid_package_program reads them from there) and hands off
        to mid_package_program. Returns True.
        """
        if config.config["gpu_control"]["open"] == "1":
            try:
                self.memory_size = config.config["memory_size"]["mm_size"]
                self.gpu_init(self.memory_size)
            except Exception as e:
                log.error("gpu error {}".format(e))
        # Data-augmentation settings for the training set.
        self.train_datagen = ImageDataGenerator(
            rescale=1. / 255,
            rotation_range=40,  # random rotation range (degrees)
            width_shift_range=0.2,  # random horizontal shift fraction
            height_shift_range=0.2,  # random vertical shift fraction
            shear_range=0.2,  # random shear intensity
            zoom_range=0.2,  # random zoom range
            horizontal_flip=True,  # enable horizontal flips
            fill_mode='nearest'  # fill strategy for created pixels
        )
        self.valid_datagen = ImageDataGenerator(rescale=1. / 255)
        image_height, image_width = self.mid_height_width(
            parameter)  # e.g. 320x480 (h, w)
        self.model_path = os.path.join(parameter["new_model_path"],
                                       parameter["train_model_name"])
        self.result_path = os.path.join(self.model_path, "model_result")
        parameter.update({"model_fileaddr": self.result_path})
        dir_operator.create_model_save_dir(self.result_path)
        # NOTE(review): these dataset paths look like placeholders — confirm
        # empty strings are intentional here.
        self.train_dataset = ""
        self.vaild_dataset = ""
        # Training-set generator. BUGFIX: the original referenced
        # `train_datagen` without `self.` — a NameError at runtime.
        self.train_generator = self.train_datagen.flow_from_directory(
            self.train_dataset,
            target_size=(image_height, image_width),
            # flow_from_directory expects a list of class names, not dict_keys.
            classes=list(parameter["target"].keys()),
            batch_size=parameter["batch_size"],
            shuffle=True,
            class_mode=config.config["parameter"]["class_mode"])
        # Validation-set generator (same missing-self bug fixed).
        self.valid_generator = self.valid_datagen.flow_from_directory(
            self.vaild_dataset,
            target_size=(image_height, image_width),
            classes=list(parameter["target"].keys()),
            batch_size=int(
                config.config["parameter"]["batch_size"]),  # per-batch size
            shuffle=True,
            class_mode=config.config["parameter"]["class_mode"])
        # BUGFIX: `data` was undefined here; mid_package_program reads
        # "train_generator"/"valid_generator" from *parameter*, so store them there.
        parameter.update({
            "train_generator": self.train_generator,
            "valid_generator": self.valid_generator
        })

        self.mid_package_program(parameter)
        return True
 def updata_train_mdoel_status(self, model_id, status):
     """Set training_status=*status* on the model row with id *model_id*.

     Returns True on success, False (logged) when the lookup or commit
     fails. The session is closed on every path via finally (the original
     could leak it if logging itself raised).
     """
     self.session = self.DBSession()
     try:
         model_info = self.session.query(
             Tables["dataset_model_info"]).filter_by(id=model_id).all()
         # An empty result raises IndexError, handled by the except below.
         model_info[0].training_status = status
         self.session.commit()
         return True
     except Exception as e:
         log.error("model:{} status change failed{}".format(model_id, e))
         return False
     finally:
         self.session.close()
示例#15
0
 def mid_merge_dataset(self, path_list):
     """Merge every dataset directory in *path_list* into the shared dataset dir.

     Images are copied per-tag into all_dataset_dir_path/<tag>/.
     Returns True on success, False (logged) on any error.
     """
     try:
         merged_root = config.config["path"]["all_dataset_dir_path"]
         self.create_dir(merged_root)
         for dataset_dir in path_list:
             for tag in os.listdir(dataset_dir):
                 print(tag)
                 self.create_dir(os.path.join(merged_root, tag))
                 tag_dir = os.path.join(dataset_dir, tag)
                 for image_name in os.listdir(tag_dir):
                     src = os.path.join(tag_dir, image_name)
                     print(src)
                     shutil.copyfile(src,
                                     os.path.join(merged_root, tag, image_name))
         return True
     except Exception as e:
         log.error("dataset merge error {}".format(e))
         return False
示例#16
0
 def train_result_save_switch(self, task_type, data, node):
     """Persist a training result according to *task_type*.

     0/1 -> init/optimise result table, 2 -> assessment result table.
     On a failed insert the task status is set to 4 and False is returned;
     unknown task types also return False.
     """
     # BUGFIX: `task_type is 0` compared object identity on ints
     # (implementation-defined, SyntaxWarning on modern Python);
     # value equality is what is meant.
     if task_type in (0, 1):
         if not dbase.insert_init_opt_result_save(data):
             log.error("init opt result save error {}".format(node))
             self.update_status(data, 4, node, data["train_model_id"])
             return False
         return True
     elif task_type == 2:
         if not dbase.insert_ass_result_save(data):
             log.error("ass result save error {}".format(node))
             self.update_status(data, 4, node, data["train_model_id"])
             return False
         return True
     else:
         return False
 def mid_create_model(self, parameter):
     """Build a classification model on top of a reflected ResNet50 backbone.

     A Flatten + Dense(64, relu) + Dense(n_classes, softmax) head is stacked
     on the backbone. Returns (True, model) on success, (False, None) on
     failure (logged).
     """
     try:
         image_height, image_width = self.mid_height_width(parameter)
         backbone_args = [False, None, (image_height, image_width, 3)]
         self.c_conv_base = self.mid_reflect("keras.applications.resnet",
                                             "ResNet50", backbone_args)
         self.x = self.c_conv_base.output
         self.x = Flatten()(self.x)
         self.x = Dense(64, activation="relu")(self.x)
         n_classes = len(parameter["target"])
         self.x = Dense(n_classes, activation="softmax")(self.x)
         self.c_model = Model(inputs=self.c_conv_base.input, outputs=self.x)
         return True, self.c_model
     except Exception as e:
         log.error("create model error {}".format(e))
         return False, None
示例#18
0
 def reponse_get_bytes(self, id):
     """Stream-download dataset *id*; the response is kept on self.response_bytes.

     Returns True when the request returned HTTP 200, False (logged)
     otherwise or on a request error.
     """
     # NOTE(review): hard-coded bearer token — should be loaded from
     # configuration/secrets, not committed in source.
     self.headers = {
         'content-type': 'application/json',
         "authorization": "Bearer wYlvTrIhhvMpfBC4cW1Wdef4FoNgJt"
     }
     try:
         self.response_bytes = requests.get(
             config.config["url"]["dataset_url"].format(id),
             headers=self.headers,
             stream=True)
         # BUGFIX: the original test (`> 200 and < 200`) could never be true,
         # so every failure was treated as success. Any non-200 is an error.
         if self.response_bytes.status_code != 200:
             log.error("dataset_{} response faile".format(id))
             return False
         return True
     except Exception as e:
         log.error("requests error %s" % e)
         return False
    def model_info_get(self, model_id, model_version):
        """Fetch a model row plus its version rows.

        Returns (model_info, requested_version_or_None, latest_version_or_None)
        or False when a query fails. When *model_version* is None only the
        latest version is resolved.
        """
        self.session = self.DBSession()
        now_model_info = self.session.query(
            Tables["dataset_model_info"]).get(model_id)
        # BUGFIX: `is 1` compared identity on an int; use value equality.
        if now_model_info.model_status == 1:
            last_version = None
        if model_version is None:
            try:
                versions = self.session.query(
                    Tables["dataset_train_model"]).filter_by(
                        training_model_id=model_id).order_by(
                            Tables["dataset_train_model"].id).all()
                # Latest version is the last row in id order (None when empty).
                last_version = versions[-1] if versions else None
                n_version_info = None
                self.session.close()

                return now_model_info, n_version_info, last_version
            except Exception as e:
                self.session.close()
                log.error("{} model:{} info get error".format(e, model_id))
                return False
        else:
            try:
                versions = self.session.query(
                    Tables["dataset_train_model"]).filter_by(
                        training_model_id=model_id).order_by(
                            Tables["dataset_train_model"].id).all()
                n_version_info = self.session.query(
                    Tables["dataset_train_model"]).filter_by(
                        training_model_id=model_id,
                        model_version=model_version).all()
                self.session.close()
                # (name, requested version, latest version)
                return now_model_info, n_version_info[0], versions[-1]
            except Exception as e:
                log.error("{} model:{} info get error".format(e, model_id))
                self.session.close()
                return False
示例#20
0
 def un_zip(self, zip_file_name):
     """Extract *zip_file_name* from the dataset directory in place.

     Directory members (no dot in the path) are created with mkdir; file
     members are written out. Returns True on success, False (logged) on
     any failure. BUGFIX: the archive is now always closed via the
     with-statement — the original leaked the ZipFile on the error path.
     """
     self.cur_path = os.path.dirname(os.path.realpath(__file__))
     d_path = config.config["path"]["datasetdir_path"].replace("./", "")
     self.zip_path = os.path.join(self.cur_path, d_path, zip_file_name)
     self.tar_path = os.path.join(self.cur_path, d_path)
     with zipfile.ZipFile(self.zip_path, "r") as z:
         for member in z.infolist():
             try:
                 save_path = (self.tar_path + "/" + member.filename).replace("\\", "/")
                 # NOTE(review): a dot anywhere in the full path is used to
                 # distinguish files from directories — fragile; confirm the
                 # dataset layout guarantees this.
                 if not os.path.exists(save_path):
                     if "." not in save_path:
                         os.mkdir(save_path)
                 if "." in save_path:
                     with open(save_path, "wb") as f:
                         f.write(z.read(member))
             except Exception as e:
                 log.error("image unzip fail {}".format(e))
                 return False
     return True
    def insert_init_opt_result_save(self, data):
        """Persist an init/optimise training result plus its per-label indexes.

        Inserts the dataset_train_model row, resolves the new row's id via
        model_info_get, then inserts one dataset_train_index row per entry in
        data["tags_index"]. Returns True on success, False (logged) otherwise.
        """
        self.session = self.DBSession()
        insert_data = Tables["dataset_train_model"](
            training_time=data["training_time"],
            accuracy_rate=data["accuracy_rate"],
            model_version=data["model_version"],
            model_fileaddr=data["model_fileaddr"],
            dataset_list=data["dataset_list"],
            training_model_id=data["training_model_id"])
        try:
            self.session.add(insert_data)
            self.session.commit()

            # NOTE(review): model_info_get rebinds and closes self.session —
            # the adds below run on that rebound session; confirm intended.
            try:
                n_traininginfo_id = self.model_info_get(
                    data["training_model_id"], None)[2].id
            except Exception as e:
                log.error("model info get error {}".format(e))
                # BUGFIX: close before bailing out — the original leaked the
                # session on this path.
                self.session.close()
                return False
            for da in data["tags_index"]:
                insert_da = Tables["dataset_train_index"](
                    F1score=da["F1score"],
                    Gscore=da["Gscore"],
                    precision_rate=da["precision_rate"],
                    recall_rate=da["recall_rate"],
                    label_id=da["label_id"],
                    traininginfo_id=n_traininginfo_id)
                self.session.add(insert_da)
            self.session.commit()
            self.session.close()
            return True
        except Exception as e:
            log.error("data insert error {}".format(e))
            self.session.close()
            return False
 def mid_package_program(self, parameter):
     """Compile, train and save the model described by *parameter*.

     Builds the model via mid_create_model, resolves loss/optimizer via
     reflection, trains with the generators stored in *parameter* and saves
     the result. Returns (True, parameter) on success, (False, parameter)
     on failure (logged).
     """
     try:
         image_height, image_width = self.mid_height_width(parameter)
         self.flag, self.m_model = self.mid_create_model(parameter)
         # BUGFIX: the original ignored the failure flag and later crashed
         # on a None model inside the broad except; fail fast instead.
         if not self.flag:
             log.error("training error {}".format("model creation failed"))
             return False, parameter
         print("-" * 10)
         # Loss: reflection returns the class, the trailing () instantiates it.
         self.m_loss = self.mid_reflect(parameter["loss_value"],
                                        "SparseCategoricalCrossentropy",
                                        [])()
         # Optimizer: reflection returns the ready object here (no call).
         self.m_optimizer = self.mid_reflect(
             parameter["optimizer"], "RMSprop",
             [float(config.config["parameter"]["optimizer_lr"])])
         self.m_model.compile(loss=self.m_loss,
                              optimizer=self.m_optimizer,
                              metrics=['accuracy'])
         # Hoist the shared batch size instead of parsing the config twice.
         batch_size = int(config.config["parameter"]["batch_size"])
         self.train_num = parameter["train_generator"].samples // batch_size
         self.valid_num = parameter["valid_generator"].samples // batch_size
         self.model_path = os.path.join(parameter["new_model_path"],
                                        parameter["train_model_name"])
         self.result_path = os.path.join(self.model_path, "model_result")
         self.m_callbacks = self.mid_model_callback(self.model_path,
                                                    self.result_path)
         self.m_model.fit_generator(
             parameter["train_generator"],
             steps_per_epoch=self.train_num,
             shuffle=True,
             epochs=int(parameter["iterate_times"]),
             validation_data=parameter["valid_generator"],
             validation_steps=self.valid_num,
             callbacks=self.m_callbacks)
         self.m_model.save("{}/all_{}.h5".format(
             parameter["model_fileaddr"], parameter["train_model_name"]))
         return True, parameter
     except Exception as e:
         log.error("training error {}".format(e))
         return False, parameter
    def insert_ass_result_save(self, data):
        """Persist a model-assessment result: summary row, per-label indexes
        and mis-classified images.

        Returns True on success, False (logged) on any failure; the session
        is closed on every failure path.
        """
        try:
            # BUGFIX: guard this lookup like insert_init_opt_result_save does;
            # the original let any failure here propagate to the caller.
            n_traininginfo_id = self.model_info_get(
                data["training_model_id"], data["model_version"])[1].id
        except Exception as e:
            log.error("model info get error {}".format(e))
            return False

        self.session = self.DBSession()
        insert_data = Tables["dataset_test_model"](
            test_time=data["training_time"],
            accuracy_rate=data["accuracy_rate"],
            dataset_list=data["dataset_list"],
            testing_model_id=n_traininginfo_id)
        try:
            self.session.add(insert_data)
            self.session.commit()
        except Exception as e:
            log.error("test data insertion failed {}".format(e))
            self.session.close()
            return False
        try:
            # Resolve the id of the test row just inserted.
            n_test_mode_id = self.test_model_info_get(n_traininginfo_id,
                                                      self.session).id
            for da in data["tags_index"]:
                insert_da = Tables["dataset_test_index"](
                    F1score=da["F1score"],
                    Gscore=da["Gscore"],
                    precision_rate=da["precision_rate"],
                    recall_rate=da["recall_rate"],
                    label_id=da["label_id"],
                    testinfo_id=n_test_mode_id)
                self.session.add(insert_da)
            self.session.commit()
        except Exception as e:
            log.error("test data index insertion failed {}".format(e))
            self.session.close()
            return False
        try:
            for da_img in data["error_images"]:
                insert_da_img = Tables["dataset_test_err"](
                    pic_address=da_img["pic_address"],
                    old_label=da_img["old_label"],
                    new_label=da_img["new_label"],
                    testinfo_id=n_test_mode_id)
                self.session.add(insert_da_img)
            self.session.commit()
        except Exception as e:
            log.error("error picture recording failed {}".format(e))
            self.session.close()
            return False
        self.session.close()
        return True
示例#24
0
 def train_parameter(self, data):
     """Resolve everything needed before a training/assessment task runs.

     Downloads and unzips the datasets, then augments *data* with the
     dataset path, model version and (for task types 1/2) the previous
     model file. Returns the enriched parameter dict, or False on failure.
     """
     self.parameter = data
     # Dataset download.
     if not d_dataset.dowload_save_bytes(
             self.parameter["train_dataset_id"]):
         log.error("Data set download failed")
         return False
     if not dir_operator.unzip_all():
         log.error("File unzip failed")
         return False
     dataset_path = config.config["path"]["all_datasetfiles"]
     if not dataset_path:
         log.error("File address acquisition failed")
         return False
     self.parameter.update({"dataset_path": dataset_path})
     # BUGFIX: `is 0/1/2` compared int identity; use value equality.
     task_type = self.parameter["task_type"]
     if task_type == 0:
         # Initial training: no prior version exists.
         self.parameter_model_version = None
         self.parameter.update(
             {"model_version": self.parameter_model_version})
         n_model_file_path = dbase.model_info_get(
             self.parameter["train_model_id"],
             self.parameter["model_version"])
         if n_model_file_path[2] is None:
             self.parameter_model_version = None
         self.parameter["model_version"] = self.parameter_model_version
         self.parameter["train_model_name"] = (
             self.parameter["train_model_name"] + "_{}".format(
                 cal_result.version_get(self.parameter["train_model_id"],
                                        self.parameter["model_version"])))
         dir_operator.create_model_save_dir("/{}/{}".format(
             os.path.split(self.parameter["new_model_path"])[1],
             self.parameter["train_model_name"]))
     elif task_type == 1:
         # Optimisation: requires an already-trained model.
         n_model_file_path = dbase.model_info_get(
             self.parameter["train_model_id"],
             self.parameter["model_version"])
         if n_model_file_path[2] is None:
             log.error("Models cannot be optimized without training")
             return False
         self.parameter.update(
             {"model_file_path": n_model_file_path[1].model_fileaddr})
         self.parameter["train_model_name"] = self.parameter[
             "train_model_name"] + "_{}".format(
                 cal_result.version_get(self.parameter["train_model_id"],
                                        self.parameter["model_version"]))
         dir_operator.create_model_save_dir("/{}/{}".format(
             os.path.split(self.parameter["new_model_path"])[1],
             self.parameter["train_model_name"]))
     elif task_type == 2:
         # Assessment: requires an already-trained model.
         n_model_file_path = dbase.model_info_get(
             self.parameter["train_model_id"],
             self.parameter["model_version"])
         if n_model_file_path[2] is None:
             log.error("Models cannot be evaluated without training")
             return False
         self.parameter.update(
             {"model_file_path": n_model_file_path[1].model_fileaddr})
     else:
         log.error("Wrong type of training")
         return False
     return self.parameter
示例#25
0
 def task_allocation_kazoo(self):
     """Poll zookeeper task nodes and run any pending training task.

     For every user node: parse the task payload, skip already-finished
     statuses, otherwise resolve parameters, train, persist the result and
     reset the dataset work directories. Errors on one node are logged and
     the loop continues with the next node.
     """
     kazoo_client.zk_client.start()
     # Statuses that mean the task is already done/terminal.
     finished_statuses = (4, 5, 8, 9, 12, 13)
     for node in kazoo_client.zk_client.get_children(
             config.config["kazoo"]["KAZOO_ROOT"]):
         self.user_node = kazoo_client.zk_client.get_children(
             config.config["kazoo"]["USER_NODE"].format(node))
         for nn in self.user_node:
             try:
                 raw = kazoo_client.zk_client.get(
                     config.config["kazoo"]["KAZOO_NODE"].format(node,
                                                                 nn))[0]
                 # Payload is JSON-ish; swap null for None so literal_eval
                 # accepts it.
                 d_data = ast.literal_eval(
                     str(raw, encoding="utf-8").replace("null", "None"))
                 # BUGFIX: `is 0` / `is 4` etc. compared int identity; use
                 # value equality / membership instead.
                 if len(d_data) == 0:
                     break
                 d_data["task_type"] = int(d_data["task_type"])
                 if d_data["training_status"] in finished_statuses:
                     print("task is trained")
                     continue
                 print("-" * 10)
                 print("task type {}".format(d_data["task_type"]))
                 d_data.update({
                     "training_time":
                     time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                 })
                 self.t_data = self.train_parameter(d_data)
                 # NOTE(review): update_status is called twice with identical
                 # arguments here — looks accidental but preserved as-is.
                 self.t_data = self.update_status(
                     self.t_data, 4, self.user_node, nn,
                     self.t_data["train_model_id"])
                 self.t_data = self.update_status(
                     self.t_data, 4, self.user_node, nn,
                     self.t_data["train_model_id"])
                 print("running task {}".format(node))
                 t_f = self.mid_train_switch(self.t_data["task_type"],
                                             self.t_data)
                 if not t_f:
                     self.update_status(self.t_data, 4, self.user_node, nn,
                                        self.t_data["train_model_id"])
                     log.error(
                         "Abnormal end of training id {}".format(node))
                     print("error task {}".format(node))
                     break
                 self.res = cal_result.get_result(
                     self.t_data, self.t_data["task_type"])
                 self.train_result_save_switch(self.t_data["task_type"],
                                               self.res, node)
                 print("end  task true{}".format(node))
                 # Reset the dataset work directories for the next task.
                 dir_operator.clear(
                     config.config["path"]["datasetdir_path"])
                 dir_operator.clear(
                     config.config["path"]["all_dataset_dir_path"])
                 dir_operator.create_dir(
                     config.config["path"]["datasetdir_path"])
                 dir_operator.create_dir(
                     config.config["path"]["all_dataset_dir_path"])
                 self.update_status(self.t_data, 4, self.user_node, nn,
                                    self.t_data["train_model_id"])
                 print("running end", True)
                 break
             except Exception as e:
                 log.error(e)
                 continue