Example #1
    def __first_augment_data(self, augmentation_num_list):
        """
        :param augmentation_num_list: 一次增强图像数目list
        :return:
        """
        bz_log.info("开始第一次数据增强")
        replace = False
        if len(self.img_list) < np.max(augmentation_num_list):
            replace = True

        for i, (augment_name,
                augment_fn) in enumerate(self.augment_fn_dict.items()):
            num_index = np.random.choice(range(len(self.img_list)),
                                         augmentation_num_list[i],
                                         replace=replace)
            for j in range(augmentation_num_list[i]):
                img_path = self.img_list[num_index][j]
                label = self.label_list[num_index][j]

                origin_img_name, img_extension = bz_path.get_file_name(
                    img_path, True)

                img_name = origin_img_name + '_' + augment_name + '_' + self.__get_timestamp()
                augmentation_img_path, augmentation_label = \
                    self.__create_data_fn(
                        img_path, label, augment_fn, img_name, self.out_file_extension_list, True)
                self.augmentation_img_list = np.append(
                    self.augmentation_img_list, augmentation_img_path)
                self.augmentation_label_list = np.append(
                    self.augmentation_label_list, augmentation_label)

        bz_log.info("完成第一次数据增强")
Example #2
    def _multiply_ratio_augmentation(self):
        bz_log.info("进行多次数据增强")
        img_num = len(self.img_list)
        augment_mode = len(self.augment_fn_dict)
        first_augmentation_ratio = np.minimum(self.augmentation_ratio,
                                              self.augmentation_split_ratio)
        bz_log.info("获得第一次增强比率%f", first_augmentation_ratio)
        num = np.int32((first_augmentation_ratio - 1) * img_num / augment_mode)
        remainder = np.int32(
            (first_augmentation_ratio - 1) * img_num % augment_mode)
        num_list = np.ones(shape=augment_mode, dtype=np.int32) * num
        num_list[0] += remainder
        # first augmentation pass
        self.__first_augment_data(num_list)
        if self.augmentation_ratio > self.augmentation_split_ratio:
            # second augmentation pass
            second_augmentation_num = np.int64(
                (self.augmentation_ratio - self.augmentation_split_ratio) *
                img_num)

            self.__second_augment_data(second_augmentation_num)

        if self.is_enhance_image_only:
            self.augmentation_img_list = self._one_ratio_augmentation()
            print('Data augmentation finished!')
            return self.augmentation_img_list
        # append the original images
        self.augmentation_img_list, self.augmentation_label_list = \
            self._one_ratio_augmentation()

        print('Data augmentation finished!')
        return self.augmentation_img_list, self.augmentation_label_list
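
The count-splitting arithmetic at the top of this method is easier to see in isolation. A minimal sketch (the function name is made up) of how the extra images are spread evenly across the augmentation methods, with the remainder absorbed by the first one:

import numpy as np

def split_augment_counts(img_num, ratio, mode_num):
    # total extra images = (ratio - 1) * img_num, spread evenly over
    # the augmentation methods; the first method takes the remainder
    extra = int((ratio - 1) * img_num)
    num_list = np.full(mode_num, extra // mode_num, dtype=np.int32)
    num_list[0] += extra % mode_num
    return num_list

print(split_augment_counts(100, 2.5, 4))  # -> [39 37 37 37], sums to 150
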
Example #3
    def __second_augment_data(self, augmentation_num):
        """
        :param augmentation_num: 二次增强图像数目
        :return:返回一次增强后图像和标签的list
        """
        bz_log.info("开始进行第二次增强")
        num_index = np.random.choice(range(len(self.augmentation_img_list)),
                                     augmentation_num,
                                     replace=True)
        for i in range(augmentation_num):
            img_path = self.augmentation_img_list[num_index][i]
            label = self.augmentation_label_list[num_index][i]
            rand_index = np.random.randint(0, len(self.augment_fn_dict))
            augment_name = list(self.augment_fn_dict.keys())[rand_index]

            origin_img_name, img_extension = bz_path.get_file_name(
                img_path, True)
            img_name = origin_img_name + '_' + augment_name + '_' + self.__get_timestamp()
            augmentation_img_path, augmentation_label = self.__create_data_fn(
                img_path, label, self.augment_fn_dict[augment_name], img_name,
                self.out_file_extension_list, True)
            self.augmentation_img_list = np.append(self.augmentation_img_list,
                                                   augmentation_img_path)
            self.augmentation_label_list = np.append(
                self.augmentation_label_list, augmentation_label)
        bz_log.info("完成第二次数据增强")
Example #4
    def __check_info_txt(self):
        if not os.path.exists(self.generate_data_folder):
            os.makedirs(self.generate_data_folder)
        img_list_path = self.generate_data_folder + os.sep + 'img_list.npy'
        label_list_path = self.generate_data_folder + os.sep + 'label_list.npy'
        result_info_path = self.generate_data_folder + os.sep + 'result_info.txt'
        result_info_old_path = self.generate_data_folder + os.sep + 'result_info_old.txt'
        if (not os.path.exists(img_list_path)
                or not os.path.exists(label_list_path)
                or not os.path.exists(result_info_path)):
            bz_log.info('Info files are missing; regenerating the data')
            shutil.rmtree(self.generate_data_folder)
            self.is_repeat_data = False
        if os.path.exists(result_info_path):
            if os.path.exists(result_info_old_path):
                os.remove(result_info_old_path)
            os.rename(result_info_path, result_info_old_path)
            self.__write_result_info()
            self.is_repeat_data = filecmp.cmp(result_info_path,
                                              result_info_old_path)
            if not self.is_repeat_data:
                shutil.rmtree(self.generate_data_folder)
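
The reuse decision boils down to rewriting the info file and comparing it with the previous copy. A self-contained sketch of that filecmp idiom with throwaway files; passing shallow=False is an extra assumption here (the method above relies on the default), forcing a byte-for-byte comparison:

import filecmp
import os
import tempfile

def write_info(path, info):
    # same key=value layout as __write_result_info
    with open(path, 'w') as f:
        for k, v in info.items():
            f.write('%s=%s\n' % (k, v))

tmp = tempfile.mkdtemp()
old_path = os.path.join(tmp, 'result_info_old.txt')
new_path = os.path.join(tmp, 'result_info.txt')
write_info(old_path, {'task': 'segmentation', 'channel': 3})
write_info(new_path, {'task': 'segmentation', 'channel': 3})
print(filecmp.cmp(new_path, old_path, shallow=False))  # True -> reuse the data
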
Example #5
    def create_data(self):
        # crx: decide whether to augment the data or reuse the results of the previous augmentation
        self.__check_info_txt()
        img_list, label_list, augmentation_ratios, generate_data_folders = \
            self.__get_augment_params()

        if not self.is_repeat_data:
            bz_log.info('Generating balanced data')
            img_list, label_list = self.__augment_data(
                img_list=img_list,
                label_list=label_list,
                augmentation_ratios=augmentation_ratios,
                channel=self.channel,
                generate_data_folders=generate_data_folders)
        else:
            # crx: keep the existing data
            bz_log.info('Reusing previously generated data')
            img_list_npy_path = self.generate_data_folder + os.sep + 'img_list.npy'
            label_list_npy_path = self.generate_data_folder + os.sep + 'label_list.npy'
            img_list, label_list = np.load(img_list_npy_path), np.load(
                label_list_npy_path)

            shuffle_indices = np.arange(len(img_list))
            np.random.shuffle(shuffle_indices)

            img_list = img_list[shuffle_indices]
            label_list = label_list[shuffle_indices]

        return img_list.T.flatten(), label_list.T.flatten(), self.is_repeat_data
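
The trailing .T.flatten() in the return is what interleaves the per-class rows so that classes alternate in the lists handed to training. A tiny illustration, assuming the (class_num, max_augment_num) shape produced by __augment_data:

import numpy as np

# two classes with three augmented samples each, as produced by
# reshape(-1, max_augment_num) in __augment_data
img_list = np.array([['a0', 'a1', 'a2'],
                     ['b0', 'b1', 'b2']])
print(img_list.T.flatten())
# -> ['a0' 'b0' 'a1' 'b1' 'a2' 'b2'], i.e. the classes are interleaved
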
Example #6
    def augmentation(self):
        # crx: a single augmentation pass only
        bz_log.info("Running data augmentation")
        for i in range(len(self.img_list)):
            img_name, img_extension = bz_path.get_file_name(
                self.img_list[i], True)
            for j in range(int(self.augmentation_ratio)):
                rand_index = np.random.randint(0, len(self.augment_fn_dict))
                augment_name = list(self.augment_fn_dict.keys())[rand_index]
                img_path, label = self.__create_data_fn(
                    self.img_list[i], self.label_list[i],
                    self.augment_fn_dict[augment_name], img_name,
                    self.out_file_extension_list, False)
                self.augmentation_img_list = np.append(
                    self.augmentation_img_list, img_path)
                self.augmentation_label_list = np.append(
                    self.augmentation_label_list, label)
        np.save(self.data_list_npy_path + '/img.npy',
                self.augmentation_img_list)
        if self.is_enhance_image_only:
            shutil.rmtree(self.augmentation_label_dir)
            return self.augmentation_img_list
        np.save(self.data_list_npy_path + '/label.npy',
                self.augmentation_label_list)

        print('Data augmentation finished!')
        return self.augmentation_img_list, self.augmentation_label_list
Example #7
def make_label_func(train_folder, eval_folder, processed_train_folder,
                    processed_eval_folder, class_num):
    # process the images and labels; these paths hold the data actually fed to the network
    bz_log.info("Generating labels...")
    generate_label_func = generate_yolo_dataset.GenerateYoloLabel()
    file_in_one_tfrecord = 50  # number of files per tfrecord

    # collect all train and eval image/label paths
    for folder in [train_folder, eval_folder]:
        if 'train' in folder:
            train_img_list, train_label_list = get_record_list(folder)
        elif 'val' in folder:
            eval_img_list, eval_label_list = get_record_list(folder)

    if not os.path.exists(processed_train_folder):
        os.makedirs(processed_train_folder)
    if not os.path.exists(processed_eval_folder):
        os.makedirs(processed_eval_folder)

    # pack all balanced_train files into several tfrecords under processed_train_folder
    pack_func(train_img_list,
              train_label_list,
              generate_label_func,
              class_num,
              processed_train_folder,
              flag='train_',
              file_in_one_tfrecord=file_in_one_tfrecord)
    # pack all balanced_eval files into several tfrecords under processed_eval_folder
    pack_func(eval_img_list,
              eval_label_list,
              generate_label_func,
              class_num,
              saved_folder=processed_eval_folder,
              flag='eval_',
              file_in_one_tfrecord=file_in_one_tfrecord)
Example #8
    def __augment_data(self, img_list, label_list, augmentation_ratios,
                       channel, generate_data_folders):
        '''
        Call the data augmentation module to augment every class by its own ratio.
        :param img_list: per-class lists of image paths
        :param label_list: per-class lists of label paths
        :param augmentation_ratios: augmentation ratio for each class
        :param channel: image channel count
        :param generate_data_folders: per-class output folders under the augmentation folder
        :return:
        '''
        bz_log.info("Augmenting every class by its own ratio via the data augmentation module")
        generate_img_list = []
        generate_label_list = []
        for i in range(len(self.sub_folders)):
            augmentation_data = python_data_augmentation.DataAugmentation(
                img_list=img_list[i],
                label_list=label_list[i],
                augmentation_ratio=augmentation_ratios[i],
                channel=channel,
                task=self.task,
                generate_data_folder=generate_data_folders[i],
                out_file_extension_list=self.out_file_extension_list)
            aug_single_class_img_list, aug_single_class_label_list = \
                augmentation_data.augment_data()

            shuffle_indices = np.arange(len(aug_single_class_img_list))
            np.random.shuffle(shuffle_indices)
            shuffle_aug_single_class_img_list = aug_single_class_img_list[
                shuffle_indices]
            shuffle_aug_single_class_label_list = aug_single_class_label_list[
                shuffle_indices]

            generate_img_list.append(
                shuffle_aug_single_class_img_list.tolist()
                [:self.max_augment_num])
            generate_label_list.append(
                shuffle_aug_single_class_label_list.tolist()
                [:self.max_augment_num])
            if self.augmentation_nums[i] != self.max_augment_num:
                for index in range(self.max_augment_num,
                                   len(aug_single_class_img_list)):
                    os.remove(shuffle_aug_single_class_label_list[index])
                    os.remove(shuffle_aug_single_class_img_list[index])
        generate_img_list = np.array(generate_img_list)
        generate_label_list = np.array(generate_label_list)
        img_list_npy_path = self.generate_data_folder + os.sep + 'img_list'
        label_list_npy_path = self.generate_data_folder + os.sep + 'label_list'
        np.save(img_list_npy_path, generate_img_list)
        np.save(label_list_npy_path, generate_label_list)
        self.all_img_nums = generate_img_list.size
        self.__write_result_info()
        return generate_img_list.reshape(
            -1, self.max_augment_num), generate_label_list.reshape(
                -1, self.max_augment_num)
Example #9
    def __init_value_judgment(self):
        if os.path.exists(self.best_model_info_path):
            with open(self.best_model_info_path, 'r') as f:
                best_model_info_dict = json.load(f)
            self.value_judgment = float(best_model_info_dict['loss'])
            bz_log.info("Loaded value_judgment: %f", self.value_judgment)
        else:
            self.value_judgment = 1000
            bz_log.info("Best model info file does not exist")
Example #10
    def __write_result_info(self):
        bz_log.info("Recording the data info for the current training run")
        with open(self.generate_data_folder + os.sep + 'result_info.txt',
                  'w') as f:
            f.write('base_folder=' + str(self.base_folder) + '\n')
            f.write('generate_data_folder=' + str(self.generate_data_folder) +
                    '\n')
            f.write('out_file_extension_list=' +
                    str(self.out_file_extension_list) + '\n')

            f.write('src_max_augment_num=' + str(self.src_max_augment_num) +
                    '\n')
            f.write('task=' + self.task + '\n')
            f.write('channel=' + str(self.channel) + '\n')
Example #11
def get_subfolder_path(folder, ret_full_path=True, is_recursion=True):
    """
    Get the paths of all subfolders under `folder`.
    :param folder: parent folder
    :param ret_full_path: whether to return full paths (full subfolder paths by default)
    :param is_recursion: whether to recurse into all subfolders, so that folders nested inside subfolders are also listed
    :return: list of subfolder paths
    """

    if not isinstance(ret_full_path, bool):
        bz_log.error('ret_full_path must be True or False: %s', ret_full_path)
        raise ValueError('ret_full_path must be True or False')
    if not os.path.isdir(folder):
        bz_log.error('The input must be a directory: %s', folder)
        raise ValueError('The input must be a directory')
    default_separation_line = '/'
    if platform.system() == 'Windows':
        default_separation_line = '\\'
        if folder[-1] != '\\':
            folder = folder + '\\'
    elif platform.system() == 'Linux':
        if folder[-1] != '/':
            folder = folder + '/'
    else:
        bz_log.error('Only Windows and Linux are supported!')
        raise ValueError('Only Windows and Linux are supported!')
    result = []
    if is_recursion:
        for root, dirs, files in os.walk(folder):
            for d in dirs:
                if ret_full_path:
                    result.append(
                        os.path.join(root, d) + default_separation_line)
                else:
                    result.append(d)
    else:
        bz_log.info("Building file paths for the current system: %s", folder)
        # os.listdir also returns files, so keep directories only
        result = [d for d in os.listdir(folder)
                  if os.path.isdir(folder + d)]
        if ret_full_path:
            result = [folder + folder_dir for folder_dir in result]
    return result
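
A hedged usage sketch; the directory layout below is invented. With is_recursion=False only direct child folders come back, while the recursive walk returns every nested folder with a trailing separator:

# hypothetical layout:
#   /data/train/cats/   /data/train/dogs/   /data/train/dogs/puppies/
subdirs = get_subfolder_path('/data/train', ret_full_path=False,
                             is_recursion=False)
# -> ['cats', 'dogs']
all_dirs = get_subfolder_path('/data/train')
# -> ['/data/train/cats/', '/data/train/dogs/', '/data/train/dogs/puppies/']
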
Example #12
    def __init__(self,
                 class_num=2,
                 model_dir='./model_dir',
                 regularizer_scale=(0, 0),
                 optimizer_fn=tf.train.AdamOptimizer,
                 background_and_foreground_loss_weight=(0.45, 0.55),
                 class_loss_weight=(1, 1, 1),
                 max_img_outputs=6,
                 learning_rate=1e-3,
                 tensor_to_log={'probablities': 'softmax'},
                 assessment_list=['accuracy', 'iou', 'recall', 'precision', 'auc']):
        """
        :param model_dir: model directory
        :param regularizer_scale: l1 and l2 regularization coefficients, as a tuple or list
        :param learning_rate: learning rate
        :param class_num: number of classes
        :param optimizer_fn: optimizer function
        :param background_and_foreground_loss_weight: background/foreground loss weights (used when only the background-and-foreground loss is applied)
        :param class_loss_weight: per-class pixel weights
        :param max_img_outputs: maximum number of images to output
        :param assessment_list: list of model evaluation metrics
        """
        self.class_num = class_num
        self.regularizer_scale = regularizer_scale
        self.optimizer_fn = optimizer_fn
        self.background_and_foreground_loss_weight = \
            background_and_foreground_loss_weight
        self.class_loss_weight = class_loss_weight
        self.max_img_outputs = max_img_outputs
        self.learning_rate = learning_rate
        self.tensor_to_log = tensor_to_log
        self.assessment_list = assessment_list
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # use only the first GPU
        session_config = tf.ConfigProto()
        # session_config.gpu_options.per_process_gpu_memory_fraction = 0.6  # cap the process at a fraction of GPU memory
        session_config.gpu_options.allow_growth = True  # allocate GPU memory on demand
        run_config = tf.estimator.RunConfig(keep_checkpoint_max=5,
                                            session_config=session_config)
        bz_log.info("begin tf init")
        super().__init__(config=run_config,
                         model_dir=model_dir,
                         model_fn=self.__model_fn)
        bz_log.info("end tf init")
        self.__check_params()
Example #13
    def __create_augmentation_data_dir(self):

        self.data_list_npy_path = self.generate_data_folder + '/augmentation_data_list_npy/'
        self.augmentation_img_dir = self.generate_data_folder + '/augmentation_img/'
        self.augmentation_label_dir = self.generate_data_folder + '/augmentation_label/'

        # recreate each output directory from scratch
        for dir_path in (self.data_list_npy_path, self.augmentation_img_dir,
                         self.augmentation_label_dir):
            if os.path.exists(dir_path):
                shutil.rmtree(dir_path)
            os.mkdir(dir_path)
        bz_log.info("Finished creating the augmentation data directories")
Example #14
def make_label_func(out_path, train_folder, eval_folder, processed_train_folder,
                    processed_eval_folder, class_num):
    # process the images and labels; these paths hold the data actually fed to the network
    bz_log.info("Generating the label format used for training")
    generate_label_func = generate_yolo_dataset.GenerateYoloLabel()

    for folder in [train_folder, eval_folder]:
        for subfolder in os.listdir(folder):
            if os.path.isdir(folder + os.sep + subfolder):
                print(subfolder)
                # balanced: paths of the augmented data
                sub_img_folder = folder + os.sep + subfolder + os.sep + 'augmentation_img'
                sub_label_folder = folder + os.sep + subfolder + os.sep + 'augmentation_label'
                # create a new folder to store the computed labels
                if 'train' in folder:
                    img_calced_folder = processed_train_folder + os.sep + subfolder + '/img/'
                    label_calced_folder = processed_train_folder + os.sep + subfolder + '/label/'
                else:
                    img_calced_folder = processed_eval_folder + os.sep + subfolder + '/img/'
                    label_calced_folder = processed_eval_folder + os.sep + subfolder + '/label/'
                if os.path.exists(label_calced_folder):
                    shutil.rmtree(label_calced_folder)
                os.makedirs(label_calced_folder)
                if os.path.exists(img_calced_folder):
                    shutil.rmtree(img_calced_folder)
                os.makedirs(img_calced_folder)

                for k, img_p in enumerate(os.listdir(sub_img_folder)):
                    # process one img/label pair
                    img_name = img_p.split('.')[0]
                    cur_img_file = sub_img_folder + os.sep + img_p
                    cur_label_file = sub_label_folder + os.sep + img_name + '.npy'

                    bz_log.info('Processing image %d: %s', k, img_p)
                    print('Processing image ' + str(k) + ':', img_p)
                    bz_log.info("Parsing data")

                    # bboxes: [center_y, center_x, height, width]
                    (img, bboxes_), cls_ = parse_simple_data(
                        cur_img_file, cur_label_file, config.img_shape[:2])

                    # bboxes: [center_y, center_x, height, width] -> [xmin, ymin, xmax, ymax]
                    bboxes = np.zeros_like(bboxes_)
                    for i, box in enumerate(bboxes_):
                        bboxes[i][0] = box[1] - box[3] / 2.
                        bboxes[i][1] = box[0] - box[2] / 2.
                        bboxes[i][2] = box[1] + box[3] / 2.
                        bboxes[i][3] = box[0] + box[2] / 2.

                    if img.shape != config.img_shape:
                        bz_log.error('Unexpected image shape: %d,%d,%d',
                                     img.shape[0], img.shape[1], img.shape[2])
                        raise ValueError('Unexpected image shape')

                    # build labels at the 3 scales
                    print(img_p)
                    label = generate_label_func(
                        np.concatenate((bboxes, np.expand_dims(cls_, 1)), axis=-1),
                        num_classes=class_num)
                    cv2.imwrite(img_calced_folder + img_p, img)
                    np.save(label_calced_folder + img_name + '.npy', label)
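
The box conversion in the inner loop is worth isolating. A standalone sketch with a worked numeric check (the helper name is made up):

import numpy as np

def center_to_corner(bboxes):
    # [center_y, center_x, height, width] -> [xmin, ymin, xmax, ymax]
    out = np.zeros_like(bboxes, dtype=np.float64)
    out[:, 0] = bboxes[:, 1] - bboxes[:, 3] / 2.  # xmin
    out[:, 1] = bboxes[:, 0] - bboxes[:, 2] / 2.  # ymin
    out[:, 2] = bboxes[:, 1] + bboxes[:, 3] / 2.  # xmax
    out[:, 3] = bboxes[:, 0] + bboxes[:, 2] / 2.  # ymax
    return out

# a 20-high, 10-wide box centered at (y=50, x=30):
print(center_to_corner(np.array([[50., 30., 20., 10.]])))
# -> [[25. 40. 35. 60.]]
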
Example #15
    def _one_ratio_augmentation(self):
        bz_log.info("进行一次数据增强")
        for i in range(len(self.img_list)):
            img_name, img_extension = bz_path.get_file_name(
                self.img_list[i], True)

            img_path, label = self.__create_data_fn(
                self.img_list[i], self.label_list[i], self.__copy, img_name,
                self.out_file_extension_list, False)
            self.augmentation_img_list = np.append(self.augmentation_img_list,
                                                   img_path)
            self.augmentation_label_list = np.append(
                self.augmentation_label_list, label)
        np.save(self.data_list_npy_path + '/img.npy',
                self.augmentation_img_list)
        if self.is_enhance_image_only:
            shutil.rmtree(self.augmentation_label_dir)
            return self.augmentation_img_list
        np.save(self.data_list_npy_path + '/label.npy',
                self.augmentation_label_list)

        print('Data augmentation finished!')
        return self.augmentation_img_list, self.augmentation_label_list
Example #16
    def train(self):
        self.sess.run(tf.global_variables_initializer())
        # save the model and register the tensors to output
        model_file = tf.train.latest_checkpoint(self.model_dir)
        # resume training from an earlier checkpoint if one exists
        if model_file is not None:
            print('load model: ' + model_file)
            self.saver.restore(self.sess, model_file)
        bz_log.info("Network initialized; starting training")
        loss_not_decrease_epoch_num = 0
        summary_writer = tf.summary.FileWriter(self.log_dir, graph=self.sess.graph)
        for epoch_index in range(self.epoch):
            print('epoch:', epoch_index)
            for step in range(self.step_num):
                global_step = epoch_index * self.image_num + step
                batch_image = self.sess.run(self.train_dataset)
                _, loss_, summary, logits_ = self.sess.run(
                    [self.train_op, self.loss_op, self.summary_op, self.logits],
                    feed_dict={self.inputs: batch_image['img']})

                step_index = step + epoch_index * self.step_num
                print("第"+  str(step_index) +"个step--:")
                print("loss:", loss_)
                summary_writer.add_summary(summary,global_step)
            # validation set
            eval_result = {}
            val_batch = self.sess.run(self.val_dataset)
            val_loss = self.sess.run(
                [self.loss_op],
                feed_dict={
                    self.inputs: val_batch['img']})
            eval_result['epoch_num'] = epoch_index + 1
            eval_result['loss'] = val_loss[0]
            eval_result["class_num"] = 255
            if self.is_socket:
                data_dict = list(eval_result.values())
                data_dict = str(data_dict).encode('utf-8')
                self.socket.send(data_dict)

            saved_model_value = eval_result['loss']
            print("eval loss------", saved_model_value)

            ckpt_file = self.model_dir + '/' + 'mm.ckpt'
            print("更新保存于:", ckpt_file)
            self.saver.save(self.sess, ckpt_file, global_step=epoch_index + 1)
            print(eval_result)

            # condition for saving the model
            if saved_model_value < self.value_judgment:
                self.value_judgment = saved_model_value
                latest_ckpt = 'mm.ckpt-' + str(epoch_index + 1)
                print("保存最佳模型")
                self.save_best_checkpoint(latest_ckpt, self.best_checkpoint_dir,
                                          eval_result)
                print("导出pb模型")
                self.export_model(self.sess)

            # early stopping
            loss_tolerance = 0.0005
            if eval_result["loss"] - self.value_judgment >= loss_tolerance:
                loss_not_decrease_epoch_num += 1
            else:
                loss_not_decrease_epoch_num = 0
            if loss_not_decrease_epoch_num > 8:
                print("导出pb模型")
                self.export_model(self.sess)
                print("early stopping 共训练%d个epoch" % epoch_index)
                break

        if self.is_socket:
            self.socket.close()
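
The early-stopping branch amounts to counting epochs whose loss fails to beat the best value seen so far by more than the tolerance. A self-contained sketch of the same pattern over synthetic losses (the function and values are illustrative):

def should_stop(losses, tolerance=0.0005, patience=8):
    best = float('inf')
    bad_epochs = 0
    for epoch, loss in enumerate(losses):
        if loss < best:
            best = loss
        # loss failed to improve on the best value beyond the tolerance
        if loss - best >= tolerance:
            bad_epochs += 1
        else:
            bad_epochs = 0
        if bad_epochs > patience:
            return epoch
    return None

print(should_stop([1.0, 0.5, 0.4] + [0.41] * 12))  # stops at epoch 11
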
Example #17
def copy_and_split_train_val_data(original_data_path,
                                  out_path,
                                  min_example_num=20,
                                  ext_list=(['jpg'], ['npy']),
                                  task='classification'):
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    for sub_folder in bz_path.get_subfolder_path(original_data_path,
                                                 ret_full_path=False,
                                                 is_recursion=False):
        img_path = original_data_path + sub_folder + '/img/'
        label_path = original_data_path + sub_folder + '/label/'

        print(img_path)
        img_copy_path = out_path + '/original_data_copy/' + sub_folder + '/img/'
        label_copy_path = out_path + '/original_data_copy/' + sub_folder + '/label/'

        if not os.path.exists(img_copy_path):
            os.makedirs(out_path + '/original_data_copy/' + sub_folder +
                        '/img/')
            os.makedirs(out_path + '/original_data_copy/' + sub_folder +
                        '/label/')

        img_file_path_list = np.sort(
            np.array(bz_path.get_file_path(img_path, ret_full_path=True)))
        label_file_path_list = np.sort(
            np.array(bz_path.get_file_path(label_path, ret_full_path=True)))

        img_list_num = len(img_file_path_list)
        if img_list_num < min_example_num:
            if task == 'classification' or task == 'segmentation':
                data_aug = DataAugmentation(
                    img_file_path_list,
                    label_file_path_list,
                    augmentation_ratio=np.ceil(min_example_num / img_list_num),
                    generate_data_folder=out_path + '/generate_data/',
                    task=task,
                    out_file_extension_list=ext_list)
                data_aug.augment_data()

                for file_path in bz_path.get_file_path(
                        out_path + '/generate_data/augmentation_img/',
                        ret_full_path=True):
                    file_name, ext = bz_path.get_file_name(file_path,
                                                           return_ext=True)
                    # img = cv2.imread(file_path)
                    # cv2.imwrite(out_path + '/original_data_copy/' + sub_folder + '/img/' + file_name + '.jpg', img)
                    img = imread(file_path, 'RGB')
                    imwrite(
                        out_path + '/original_data_copy/' + sub_folder +
                        '/img/' + file_name + '.jpg', img)

                if task.lower() == 'classification':
                    for file_path in bz_path.get_file_path(
                            out_path + '/generate_data/augmentation_label/',
                            ret_full_path=False):
                        shutil.copy(
                            out_path + '/generate_data/augmentation_label/' +
                            file_path, out_path + '/original_data_copy/' +
                            sub_folder + '/label/' + file_path)
                elif task.lower() == 'segmentation':
                    for file_path in bz_path.get_file_path(
                            out_path + '/generate_data/augmentation_label/',
                            ret_full_path=True):
                        file_name, ext = bz_path.get_file_name(file_path,
                                                               return_ext=True)
                        # label = cv2.imread(file_path, 0)
                        # cv2.imwrite(out_path + '/original_data_copy/' + sub_folder + '/label/' + file_name + ".png", label)
                        label = imread(file_path, 'gray')
                        imwrite(
                            out_path + '/original_data_copy/' + sub_folder +
                            '/label/' + file_name + ".png", label)
            elif task == 'detection':
                print('Fewer samples than min_example_num; running pre-augmentation...')
                bz_log.info('Fewer samples than min_example_num; running pre-augmentation...')
                # convert the txt labels to npy format
                txt2npy_path = out_path + '/txt2npy/' + sub_folder + '/label/'
                if os.path.exists(txt2npy_path):
                    shutil.rmtree(txt2npy_path)
                os.makedirs(txt2npy_path)

                for file_path in bz_path.get_file_path(label_path,
                                                       ret_full_path=True):
                    file_name, ext = bz_path.get_file_name(file_path,
                                                           return_ext=True)
                    if ext == 'txt':  # convert the label format
                        with open(file_path, 'r') as f:
                            lines = f.readlines()
                        data = []
                        for line in lines:
                            temp = list(map(int, line.strip().split(',')))
                            data.append(
                                [temp[1], temp[0], temp[3], temp[2], temp[4]])
                        np.save(txt2npy_path + file_name + ".npy", data)
                if len(os.listdir(txt2npy_path)) != 0:
                    label_file_path_list = np.sort(
                        np.array(
                            bz_path.get_file_path(txt2npy_path,
                                                  ret_full_path=True)))

                yolo_min_example_augmentation_data = DataAugmentation(
                    img_list=img_file_path_list,
                    label_list=label_file_path_list,
                    channel=3,
                    augmentation_ratio=np.ceil(min_example_num / img_list_num),
                    # augmentation multiplier
                    generate_data_folder=out_path + '/generate_data/' +
                    sub_folder + os.sep,
                    task='detection')
                yolo_min_example_augmentation_data.augment_data()

                for file_path in bz_path.get_file_path(
                        out_path + '/generate_data/' + sub_folder +
                        '/augmentation_img/',
                        ret_full_path=True):
                    file_name, ext = bz_path.get_file_name(file_path,
                                                           return_ext=True)
                    img = cv2.imread(file_path)
                    cv2.imwrite(
                        out_path + '/original_data_copy/' + sub_folder +
                        '/img/' + file_name + '.jpg', img)

                for file_path in bz_path.get_file_path(
                        out_path + '/generate_data/' + sub_folder +
                        '/augmentation_label/',
                        ret_full_path=True):
                    file_name, ext = bz_path.get_file_name(file_path,
                                                           return_ext=True)
                    shutil.copy(
                        file_path, out_path + '/original_data_copy/' +
                        sub_folder + '/label/' + file_name + ".npy")
        else:

            for file_path in bz_path.get_file_path(img_path,
                                                   ret_full_path=True):
                file_name, ext = bz_path.get_file_name(file_path,
                                                       return_ext=True)
                # img = cv2.imread(file_path)
                # cv2.imwrite(out_path + '/original_data_copy/' + sub_folder + '/img/' + file_name + '.jpg', img)
                img = imread(file_path, 'rgb')
                imwrite(
                    out_path + '/original_data_copy/' + sub_folder + '/img/' +
                    file_name + '.jpg', img)

            if task.lower() == 'classification':
                for file_path in label_file_path_list:
                    file_name, ext = bz_path.get_file_name(file_path,
                                                           return_ext=True)
                    shutil.copy(
                        file_path, out_path + '/original_data_copy/' +
                        sub_folder + '/label/' + file_name + ".npy")
            elif task.lower() == 'segmentation':
                for file_path in bz_path.get_file_path(label_path,
                                                       ret_full_path=True):
                    file_name, ext = bz_path.get_file_name(file_path,
                                                           return_ext=True)
                    # label = cv2.imread(file_path, 0)
                    # cv2.imwrite(out_path + '/original_data_copy/' + sub_folder + '/label/' + file_name + ".png", label)
                    label = imread(file_path, 'gray')
                    imwrite(
                        out_path + '/original_data_copy/' + sub_folder +
                        '/label/' + file_name + ".png", label)
            elif task.lower() == 'detection':
                for file_path in bz_path.get_file_path(label_path,
                                                       ret_full_path=True):
                    file_name, ext = bz_path.get_file_name(file_path,
                                                           return_ext=True)
                    if ext == 'txt':  # convert the label format
                        with open(file_path, 'r') as f:
                            lines = f.readlines()
                        data = []
                        for line in lines:
                            temp = list(map(int, line.strip().split(',')))
                            data.append(
                                [temp[1], temp[0], temp[3], temp[2], temp[4]])
                        np.save(
                            out_path + '/original_data_copy/' + sub_folder +
                            '/label/' + file_name + ".npy", data)
                    elif ext == 'npy':
                        shutil.copy(
                            file_path, out_path + '/original_data_copy/' +
                            sub_folder + '/label/' + file_name + ".npy")

        img_list, label_list = bz_path.get_img_label_path_list(
            img_copy_path, label_copy_path, ret_full_path=True)
        if task.lower() != 'segmentation' and task.lower() != 'detection':
            label_list = np.array(
                [np.loadtxt(label_path) for label_path in label_list])
        split_train_eval_test_data(img_list,
                                   label_list,
                                   out_path + '/train/' + sub_folder,
                                   out_path + '/val/' + sub_folder,
                                   out_path + '/test/' + sub_folder,
                                   task=task)
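
The txt-to-npy conversion above swaps each coordinate pair into (y, x) order, keeping the class id last. This sketch assumes the txt rows are stored as xmin,ymin,xmax,ymax,class, which is what the index swap in the loop suggests; names and the output path are illustrative:

import numpy as np

def parse_detection_line(line):
    # assumed storage order: xmin, ymin, xmax, ymax, class_id
    x1, y1, x2, y2, cls = map(int, line.strip().split(','))
    # reordered to: ymin, xmin, ymax, xmax, class_id
    return [y1, x1, y2, x2, cls]

lines = ['10,20,110,220,1', '5,5,50,60,0']
data = [parse_detection_line(line) for line in lines]
np.save('/tmp/example_label.npy', data)  # illustrative output path
print(data)  # [[20, 10, 220, 110, 1], [5, 5, 60, 50, 0]]
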
Example #18
    def fit(self,
            train_features,
            train_labels,
            eval_features=None,
            eval_labels=None):
        """
        :param train_features: 训练图像路径
        :param train_labels: 训练标签路径
        :param eval_features: 验证图像路径
        :param eval_labels: 验证标签路径
        :return:
        """
        # 交叉验证

        # value = 0
        # 切换loss的epoch num
        change_loss_epoch = round(self.epoch_num *
                                  self.change_loss_fn_threshold)

        if self.k_fold > 1:
            print("交叉验证")
            for epoch_index in range(self.epoch_num):
                if epoch_index > self.epoch_num * self.change_loss_fn_threshold:
                    self.estimator_obj.use_background_and_foreground_loss = False
                data_list = cross_validation.create_cross_validation_data(
                    self.k_fold, train_features, train_labels)
                eval_result = {}
                cross_eval_result = {}
                for j in range(self.k_fold):
                    sub_train_features, \
                        sub_train_labels, \
                        sub_eval_features, \
                        sub_eval_labels = data_list[j]
                    self.estimator_obj.train(lambda: self.__train_input_fn(
                        sub_train_features, sub_train_labels))
                    cross_eval_result = self.estimator_obj.evaluate(
                        lambda: self.__eval_input_fn(sub_eval_features,
                                                     sub_eval_labels))

                    for key, value in cross_eval_result.items():
                        if key not in ['global_step']:
                            if j == 0:
                                eval_result[key] = value / self.k_fold
                            else:
                                eval_result[key] += value / self.k_fold
                eval_result['global_step'] = cross_eval_result['global_step']

                print('\033[1;36m Cross-validation results: epoch_index=' + str(epoch_index))
                for k, v in eval_result.items():
                    print(k + ' =', v)
                print('\033[0m')

                if self.is_socket:
                    eval_result['epoch_num'] = epoch_index + 1
                    eval_result["class_num"] = self.class_num
                    data_dict = list(eval_result.values())
                    data_dict = str(data_dict).encode('utf-8')
                    self.socket.send(data_dict)

                saved_model_value = self.calculate_saved_model_value_callback(
                    *self.accuracy_weight, **eval_result)
                # condition for saving the model
                if saved_model_value > self.value_judgment:
                    self.value_judgment = saved_model_value
                    eval_result['value_judgment'] = self.value_judgment
                    self.export_model_dir = self.model_dir + '/export_model_dir'
                    self.export_model(export_model_dir=self.export_model_dir)
                    with open(
                            self.best_checkpoint_dir + '/best_model_info.json',
                            'w') as f:
                        eval_result_dict = {
                            k: str(v)
                            for k, v in eval_result.items()
                        }
                        json.dump(eval_result_dict, f, indent=1)

        else:
            loss_not_decrease_epoch_num = 0
            for epoch_index in range(0, self.epoch_num, self.eval_epoch_step):
                if eval_features is None or eval_labels is None:
                    raise ValueError('An evaluation set is required when not using cross-validation!')
                if epoch_index > self.epoch_num * self.change_loss_fn_threshold:
                    self.estimator_obj.use_background_and_foreground_loss = False
                self.estimator_obj.train(lambda: self.__train_input_fn(
                    train_features, train_labels))
                eval_result = self.estimator_obj.evaluate(
                    lambda: self.__eval_input_fn(eval_features, eval_labels))
                print("获得验证结果,开始数据传输")
                print('\033[1;36m 验证集结果:epoch_index=' + str(epoch_index))
                for k, v in eval_result.items():
                    print(k + ' =', v)

                print('\033[0m')
                if self.is_socket:
                    eval_result[
                        'epoch_num'] = epoch_index / self.eval_epoch_step + 1
                    eval_result["class_num"] = self.class_num
                    data_dict = list(eval_result.values())
                    data_dict = str(data_dict).encode('utf-8')
                    self.socket.send(data_dict)

                # saved_model_value = self.calculate_saved_model_value_callback(*self.accuracy_weight, **eval_result)
                saved_model_value = eval_result['loss']
                # condition for saving the model
                if saved_model_value < self.value_judgment:
                    self.value_judgment = saved_model_value
                    eval_result['value_judgment'] = self.value_judgment
                    self.export_model_dir = self.model_dir + '/export_model_dir'
                    self.export_model(export_model_dir=self.export_model_dir,
                                      eval_result=eval_result)

                # early stopping
                if self.is_early_stop and epoch_index > change_loss_epoch:
                    loss_tolerance = 0.0005
                    if eval_result[
                            "loss"] - self.value_judgment >= loss_tolerance:
                        loss_not_decrease_epoch_num += 1
                    else:
                        loss_not_decrease_epoch_num = 0
                    if loss_not_decrease_epoch_num > 5:
                        bz_log.info("early stopping 共训练%d个epoch%d",
                                    epoch_index)
                        bz_log.info("is early stop%s", self.is_early_stop)
                        print("early stopping 共训练%d个epoch" % epoch_index)
                        break

        if self.is_socket:
            # self.socket.send(b'exit')
            self.socket.close()
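
The cross-validation branch averages each metric incrementally across folds (value / k_fold summed per fold). An isolated sketch of that accumulation with fabricated fold results:

def average_fold_metrics(fold_results):
    # incrementally average every metric across folds, skipping
    # global_step, mirroring the accumulation in fit()
    k = len(fold_results)
    averaged = {}
    for result in fold_results:
        for key, value in result.items():
            if key == 'global_step':
                continue
            averaged[key] = averaged.get(key, 0) + value / k
    return averaged

folds = [{'loss': 0.30, 'accuracy': 0.90, 'global_step': 100},
         {'loss': 0.20, 'accuracy': 0.94, 'global_step': 200}]
print(average_fold_metrics(folds))  # {'loss': 0.25, 'accuracy': 0.92}
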
Example #19
    def fit(self, train_records, eval_records):
        # recall_value, loss_value = 0, 1000
        # model-saving condition and step control
        if os.path.exists(self.model_dir + os.sep +
                          'best_checkpoint_dir/eval_metrics.json'):
            with open(self.model_dir + os.sep +
                      'best_checkpoint_dir/eval_metrics.json') as f:
                last_eval_res = json.load(f)
            loss_value = float(last_eval_res['loss'])
        else:
            loss_value = 10000
        loss_not_decrease_epoch_num = 0
        for epoch_index in range(0, self.epoch_num, self.eval_epoch_step):
            self.estimator_obj.train(
                input_fn=lambda: self.train_input_fn(train_records))
            # if epoch_index % self.eval_epoch_step ==0:
            eval_result = self.estimator_obj.evaluate(
                input_fn=lambda: self.eval_input_fn(eval_records))
            print('Evaluation results for epoch', epoch_index, ':')
            for k, v in eval_result.items():
                print(k, ':', v)

            # send eval_result over the socket
            if self.is_socket:
                bz_log.info("epoch_num%d:", epoch_index + 1)
                eval_result[
                    'epoch_num'] = epoch_index / self.eval_epoch_step + 1
                eval_result["class_num"] = self.class_num
                data_dict = list(eval_result.values())[1:]
                data_dict = str(data_dict).encode('utf-8')
                self.socket.send(data_dict)

            # model update and saving; both conditions apply
            # eval_recall_value, eval_loss_value = self.__get_saved_model_value(eval_result)
            eval_loss_value = self.__get_saved_model_value(eval_result)
            # if eval_recall_value>=recall_value or eval_loss_value <= loss_value:
            if eval_loss_value <= loss_value:
                # export the model
                self.export_model()
                with open(self.best_checkpoint_dir + '/eval_metrics.json',
                          'w') as f:
                    eval_result_dict = {
                        k: str(v)
                        for k, v in eval_result.items()
                    }
                    json.dump(eval_result_dict, f, indent=1)
                # recall_value,loss_value=eval_recall_value,eval_loss_value
                loss_value = eval_loss_value

            # early stopping
            if self.is_early_stop:
                loss_tolerance = 0.0005
                if eval_result["loss"] - loss_value >= loss_tolerance:
                    loss_not_decrease_epoch_num += 1
                else:
                    loss_not_decrease_epoch_num = 0
                if loss_not_decrease_epoch_num > 5:
                    print("early stopping 共训练%d个epoch" % epoch_index)
                    break

        # save the frozen pb model; the conversion runs when training stops
        self.convert_export_model_to_pb()