Example #1
    def __init__(self,
                 input_shape=(224, 224, 3),
                 datasets_dir=None,
                 datasets_zip=None,
                 unpack_dir=None,
                 logger=None,
                 max_classes_limit=15,
                 one_class_min_images_num=100,
                 one_class_max_images_num=2000,
                 allow_reshape=True,
                 support_shapes=((224, 224, 3), (240, 240, 3))):
        '''
            input_shape: input shape (height, width, channels)
            one_class_min_images_num: minimum number of images per class
        '''
        import tensorflow as tf  # imported here (not at module level) to work with multiprocessing
        self.tf = tf
        self.need_rm_datasets = False
        self.input_shape = input_shape
        self.support_shapes = support_shapes
        if self.input_shape not in self.support_shapes:
            raise Exception(
                "input shape {} not supported, supported shapes: {}".format(
                    self.input_shape, self.support_shapes))
        self.allow_reshape = allow_reshape  # if a dataset image's shape differs from the required shape, reshape it
        self.config_max_classes_limit = max_classes_limit
        self.config_one_class_min_images_num = one_class_min_images_num
        self.config_one_class_max_images_num = one_class_max_images_num
        self.datasets_rm_dir = None
        self.model = None
        self.history = None
        self.warning_msg = []  # append warning message here
        if logger:
            self.log = logger
        else:
            self.log = Fake_Logger()
        self.datasets_val_list = []  # filled only when a JSON dataset list provides "val"
        # unzip datasets
        if datasets_zip:
            unpacked_dir = self._unpack_datasets(datasets_zip, unpack_dir)
            if not unpacked_dir:
                self.log.e("can't detect datasets, check zip format")
                raise Exception("can't detect datasets, check zip format")
            self.datasets_list = [unpacked_dir]
        elif datasets_dir:
            if os.path.isdir(datasets_dir):
                self.datasets_list = [datasets_dir]
            else:
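                # datasets_dir may also be a JSON file listing "train" and "val" dataset directories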
                with open(datasets_dir, "r") as f:
                    data = json.load(f)
                self.datasets_list = data["train"]
                self.datasets_val_list = data["val"]
        else:
            self.log.e("no datasets args")
            raise Exception("no datasets args")
        # parse train datasets
        print("Scanning train datasets")
        self.labels = []
        classes_data_counts = []
        datasets_x = []
        datasets_y = []
        for dataset_dir in self.datasets_list:
            ok, msg, self.labels, _classes_data_counts, _datasets_x, _datasets_y = self._load_datasets(
                dataset_dir.replace("\r", "").replace("\n", ""))
            if not ok:
                msg = f"datasets format error: {msg}"
                self.log.e(msg)
                raise Exception(msg)
            if not classes_data_counts:
                classes_data_counts = _classes_data_counts
            else:
                for i in range(len(_classes_data_counts)):
                    classes_data_counts[i] += _classes_data_counts[i]

            datasets_x.extend(_datasets_x)
            datasets_y.extend(_datasets_y)

        # without a separate val datasets list there is nothing more to prepare here
        if not self.datasets_val_list:
            return

        ok, err_msg = self._is_datasets_valid(
            self.labels,
            classes_data_counts,
            one_class_min_images_num=self.config_one_class_min_images_num,
            one_class_max_images_num=self.config_one_class_max_images_num)
        if not ok:
            self.log.e(err_msg)
            raise Exception(err_msg)
        self.log.i(
            "train datasets loaded, check passed, images num:{}, bboxes num:{}"
            .format(len(datasets_x), sum(classes_data_counts)))
        self.datasets_img = np.array(datasets_x, dtype='uint8')
        self.datasets_ann = datasets_y

        # parse val datasets
        print("Scanning val datasets")
        datasets_x = []
        datasets_y = []
        classes_data_counts = []
        for dataset_dir in self.datasets_val_list:
            ok, msg, self.labels, _classes_data_counts, _datasets_x, _datasets_y = self._load_datasets(
                dataset_dir.replace("\r", "").replace("\n", ""))
            if not ok:
                msg = f"datasets format error: {msg}"
                self.log.e(msg)
                raise Exception(msg)

            if not classes_data_counts:
                classes_data_counts = _classes_data_counts
            else:
                for i in range(len(_classes_data_counts)):
                    classes_data_counts[i] += _classes_data_counts[i]

            datasets_x.extend(_datasets_x)
            datasets_y.extend(_datasets_y)

        self.log.i(
            "val datasets loaded, images num:{}, bboxes num:{}"
            .format(len(datasets_x), sum(classes_data_counts)))
        self.datasets_val_img = np.array(datasets_x, dtype='uint8')
        self.datasets_val_ann = datasets_y

        class _Train_progress_cb(tf.keras.callbacks.Callback):  # training progress / remaining-time callback
            def __init__(self, epochs, user_progress_callback, logger):
                self.epochs = epochs
                self.logger = logger
                self.user_progress_callback = user_progress_callback

            def on_epoch_begin(self, epoch, logs=None):
                self.logger.i("epoch {} start".format(epoch))

            def on_epoch_end(self, epoch, logs=None):
                self.logger.i("epoch {} end: {}".format(epoch, logs))
                if self.user_progress_callback:
                    self.user_progress_callback(
                        (epoch + 1) / self.epochs * 100, "train epoch end")

            def on_train_begin(self, logs=None):
                self.logger.i("train start")
                if self.user_progress_callback:
                    self.user_progress_callback(0, "train start")

            def on_train_end(self, logs=None):
                self.logger.i("train end")
                if self.user_progress_callback:
                    self.user_progress_callback(100, "train end")

        self.Train_progress_cb = _Train_progress_cb
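The nested _Train_progress_cb simply forwards Keras epoch events to a user-supplied progress function. Below is a minimal, self-contained sketch of the same pattern; the toy model, data, and print-based handler are illustrative assumptions, not part of the original project.

import numpy as np
import tensorflow as tf

class ProgressCb(tf.keras.callbacks.Callback):
    def __init__(self, epochs, on_progress):
        self.epochs = epochs
        self.on_progress = on_progress

    def on_epoch_end(self, epoch, logs=None):
        # mirrors (epoch + 1) / epochs * 100 from the example above
        self.on_progress((epoch + 1) / self.epochs * 100, "train epoch end")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, epochs=5, verbose=0,
          callbacks=[ProgressCb(5, lambda p, msg: print(f"{p:.0f}% {msg}"))])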
Example #2
class Detector(Train_Base):
    def __init__(self,
                 input_shape=(224, 224, 3),
                 datasets_dir=None,
                 datasets_zip=None,
                 unpack_dir=None,
                 logger=None,
                 max_classes_limit=15,
                 one_class_min_images_num=100,
                 one_class_max_images_num=2000,
                 allow_reshape=True,
                 support_shapes=((224, 224, 3), (240, 240, 3))):
        '''
            input_shape: input shape (height, width, channels)
            one_class_min_images_num: minimum number of images per class
        '''
        import tensorflow as tf  # imported here (not at module level) to work with multiprocessing
        self.tf = tf
        self.need_rm_datasets = False
        self.input_shape = input_shape
        self.support_shapes = support_shapes
        if self.input_shape not in self.support_shapes:
            raise Exception(
                "input shape {} not supported, supported shapes: {}".format(
                    self.input_shape, self.support_shapes))
        self.allow_reshape = allow_reshape  # if a dataset image's shape differs from the required shape, reshape it
        self.config_max_classes_limit = max_classes_limit
        self.config_one_class_min_images_num = one_class_min_images_num
        self.config_one_class_max_images_num = one_class_max_images_num
        self.datasets_rm_dir = None
        self.model = None
        self.history = None
        self.warning_msg = []  # append warning message here
        if logger:
            self.log = logger
        else:
            self.log = Fake_Logger()
        self.datasets_val_list = []  # filled only when a JSON dataset list provides "val"
        # unzip datasets
        if datasets_zip:
            unpacked_dir = self._unpack_datasets(datasets_zip, unpack_dir)
            if not unpacked_dir:
                self.log.e("can't detect datasets, check zip format")
                raise Exception("can't detect datasets, check zip format")
            self.datasets_list = [unpacked_dir]
        elif datasets_dir:
            if os.path.isdir(datasets_dir):
                self.datasets_list = [datasets_dir]
            else:
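                # datasets_dir may also be a JSON file listing "train" and "val" dataset directories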
                with open(datasets_dir, "r") as f:
                    data = json.load(f)
                self.datasets_list = data["train"]
                self.datasets_val_list = data["val"]
        else:
            self.log.e("no datasets args")
            raise Exception("no datasets args")
        # parse train datasets
        print("Scanning train datasets")
        self.labels = []
        classes_data_counts = []
        datasets_x = []
        datasets_y = []
        for dataset_dir in self.datasets_list:
            ok, msg, self.labels, _classes_data_counts, _datasets_x, _datasets_y = self._load_datasets(
                dataset_dir.replace("\r", "").replace("\n", ""))
            if not ok:
                msg = f"datasets format error: {msg}"
                self.log.e(msg)
                raise Exception(msg)
            if not classes_data_counts:
                classes_data_counts = _classes_data_counts
            else:
                for i in range(len(_classes_data_counts)):
                    classes_data_counts[i] += _classes_data_counts[i]

            datasets_x.extend(_datasets_x)
            datasets_y.extend(_datasets_y)

        # without a separate val datasets list there is nothing more to prepare here
        if not self.datasets_val_list:
            return

        ok, err_msg = self._is_datasets_valid(
            self.labels,
            classes_data_counts,
            one_class_min_images_num=self.config_one_class_min_images_num,
            one_class_max_images_num=self.config_one_class_max_images_num)
        if not ok:
            self.log.e(err_msg)
            raise Exception(err_msg)
        self.log.i(
            "train datasets loaded, check passed, images num:{}, bboxes num:{}"
            .format(len(datasets_x), sum(classes_data_counts)))
        self.datasets_img = np.array(datasets_x, dtype='uint8')
        self.datasets_ann = datasets_y

        # parse val datasets
        print("Scanning val datasets")
        datasets_x = []
        datasets_y = []
        classes_data_counts = []
        for dataset_dir in self.datasets_val_list:
            ok, msg, self.labels, _classes_data_counts, _datasets_x, _datasets_y = self._load_datasets(
                dataset_dir.replace("\r", "").replace("\n", ""))
            if not ok:
                msg = f"datasets format error: {msg}"
                self.log.e(msg)
                raise Exception(msg)

            if not classes_data_counts:
                classes_data_counts = _classes_data_counts
            else:
                for i in range(len(_classes_data_counts)):
                    classes_data_counts[i] += _classes_data_counts[i]

            datasets_x.extend(_datasets_x)
            datasets_y.extend(_datasets_y)

        self.log.i(
            "val datasets loaded, images num:{}, bboxes num:{}"
            .format(len(datasets_x), sum(classes_data_counts)))
        self.datasets_val_img = np.array(datasets_x, dtype='uint8')
        self.datasets_val_ann = datasets_y

        class _Train_progress_cb(tf.keras.callbacks.Callback):  # training progress / remaining-time callback
            def __init__(self, epochs, user_progress_callback, logger):
                self.epochs = epochs
                self.logger = logger
                self.user_progress_callback = user_progress_callback

            def on_epoch_begin(self, epoch, logs=None):
                self.logger.i("epoch {} start".format(epoch))

            def on_epoch_end(self, epoch, logs=None):
                self.logger.i("epoch {} end: {}".format(epoch, logs))
                if self.user_progress_callback:
                    self.user_progress_callback(
                        (epoch + 1) / self.epochs * 100, "train epoch end")

            def on_train_begin(self, logs=None):
                self.logger.i("train start")
                if self.user_progress_callback:
                    self.user_progress_callback(0, "train start")

            def on_train_end(self, logs=None):
                self.logger.i("train end")
                if self.user_progress_callback:
                    self.user_progress_callback(100, "train end")

        self.Train_progress_cb = _Train_progress_cb

    def __del__(self):
        if self.need_rm_datasets:
            try:
                # datasets_rm_dir holds the temp directory recorded by _unpack_datasets
                shutil.rmtree(self.datasets_rm_dir)
                self.log.i(f"clean temp dataset dir:{self.datasets_rm_dir}")
            except Exception as e:
                try:
                    self.log.e("clean temp files error:{}".format(e))
                except Exception:
                    print(
                        "log object invalid, var scope usage error, check code"
                    )

    def _get_anchors(self,
                     bboxes_in,
                     input_shape=(224, 224),
                     clusters=5,
                     strip_size=32):
        '''
            @input_shape tuple (h, w)
            @bboxes_in format: [ [[xmin, ymin, xmax, ymax, label], ], ]
                        value range: x [0, w], y [0, h]
            @return anchors, flat list of 2 * clusters values (w, h pairs in grid units)
        '''
        w = input_shape[1]
        h = input_shape[0]
        # TODO: add position to iou, not only box size
        bboxes = []
        for items in bboxes_in:
            for bbox in items:
                bboxes.append(
                    ((bbox[2] - bbox[0]) / w, (bbox[3] - bbox[1]) / h))
        bboxes = np.array(bboxes)
        self.log.i(f"bboxes num: {len(bboxes)}, first bbox: {bboxes[0]}")
        out = kmeans.kmeans(bboxes, k=clusters)
        iou = kmeans.avg_iou(bboxes, out) * 100
        self.log.i("bbox accuracy(IOU): {:.2f}%".format(iou))
        self.log.i("bound boxes: {}".format(",".join(
            "({:f},{:.2f})".format(item[0] * w, item[1] * h) for item in out)))
        for i, wh in enumerate(out):
            out[i][0] = wh[0] * w / strip_size
            out[i][1] = wh[1] * h / strip_size
        anchors = list(out.flatten())
        self.log.i(f"anchors: {anchors}")
        ratios = np.around(out[:, 0] / out[:, 1], decimals=2).tolist()
        self.log.i("w/h ratios: {}".format(sorted(ratios)))
        return anchors

    def train(
        self,
        epochs=100,
        progress_cb=None,
        weights=os.path.join(curr_file_dir, "weights",
                             "mobilenet_7_5_224_tf_no_top.h5"),
        batch_size=5,
        train_times=5,
        valid_times=2,
        learning_rate=1e-4,
        jitter=False,
        is_only_detect=False,
        save_best_weights_path="out/best_weights.h5",
        save_final_weights_path="out/final_weights.h5",
    ):
        import tensorflow as tf
        from yolo.frontend import create_yolo

        self.log.i("train, labels:{}".format(self.labels))
        self.log.d("train, datasets dir:{}".format(self.datasets_list))

        # param check
        # TODO: check more param
        if len(self.labels) == 1:
            is_only_detect = True
        self.save_best_weights_path = save_best_weights_path
        self.save_final_weights_path = save_final_weights_path

        # create yolo model
        strip_size = 32 if min(self.input_shape[:2]) % 32 == 0 else 16
        # get anchors
        self.anchors = self._get_anchors(self.datasets_ann,
                                         self.input_shape[:2],
                                         strip_size=strip_size)
        # create network
        yolo = create_yolo(architecture="MobileNet",
                           labels=self.labels,
                           input_size=self.input_shape[:2],
                           anchors=self.anchors,
                           coord_scale=1.0,
                           class_scale=1.0,
                           object_scale=5.0,
                           no_object_scale=1.0,
                           weights=weights,
                           strip_size=strip_size)

        # train
        self.history = yolo.train(
            img_folder=None,
            ann_folder=None,
            img_in_mem=self.datasets_img,  # in-memory dataset images, format: list
            ann_in_mem=self.datasets_ann,  # in-memory dataset annotations, format: list
            nb_epoch=epochs,
            save_best_weights_path=save_best_weights_path,
            save_final_weights_path=save_final_weights_path,
            batch_size=batch_size,
            jitter=jitter,
            learning_rate=learning_rate,
            train_times=train_times,
            valid_times=valid_times,
            valid_img_folder="",
            valid_ann_folder="",
            valid_img_in_mem=self.datasets_val_img,
            valid_ann_in_mem=self.datasets_val_ann,
            first_trainable_layer=None,
            is_only_detect=is_only_detect,
            progress_callbacks=[
                self.Train_progress_cb(epochs, progress_cb, self.log)
            ])

    def report(self, out_path, limit_y_range=None):
        '''
            generate result charts
        '''
        self.log.i("generate report image")
        if not self.history:
            return
        history = self.history
        print(history)

        # use the non-GUI agg backend so report generation works on servers without Tk support
        plt.switch_backend('agg')

        fig, axes = plt.subplots(1,
                                 1,
                                 constrained_layout=True,
                                 figsize=(16, 10),
                                 dpi=100)
        if limit_y_range:
            plt.ylim(limit_y_range)

        # acc and val_acc
        # {'loss': [0.5860330664989357, 0.3398533443955177], 'accuracy': [0.70944744, 0.85026735], 'val_loss': [0.4948340670338699, 0.49342870752194096], 'val_accuracy': [0.7, 0.74285716]}
        if "acc" in history.history:
            kws = {
                "acc": "acc",
                "val_acc": "val_acc",
                "loss": "loss",
                "val_loss": "val_loss"
            }
        else:
            kws = {
                "acc": "accuracy",
                "val_acc": "val_accuracy",
                "loss": "loss",
                "val_loss": "val_loss"
            }
        # axes[0].plot( history.history[kws['acc']], color='#2886EA', label="train")
        # axes[0].plot( history.history[kws['val_acc']], color = '#3FCD6D', label="valid")
        # axes[0].set_title('model accuracy')
        # axes[0].set_ylabel('accuracy')
        # axes[0].set_xlabel('epoch')
        # axes[0].locator_params(integer=True)
        # axes[0].legend()

        # loss and val_loss
        axes.plot(history.history[kws['loss']], color='#2886EA', label="train")
        axes.plot(history.history[kws['val_loss']],
                  color='#3FCD6D',
                  label="valid")
        axes.set_title('model loss')
        axes.set_ylabel('loss')
        axes.set_xlabel('epoch')
        axes.locator_params(integer=True)
        axes.legend()

        # confusion matrix
        # cm, labels_idx = self._get_confusion_matrix()
        # axes[2].imshow(cm, interpolation='nearest', cmap = plt.cm.GnBu)
        # axes[2].set_title("confusion matrix")
        # # axes[2].colorbar()
        # num_local = np.array(range(len(labels_idx)))
        # axes[2].set_xticks(num_local)
        # axes[2].set_xticklabels(labels_idx.keys(), rotation=45)
        # axes[2].set_yticks(num_local)
        # axes[2].set_yticklabels(labels_idx.keys())

        # thresh = cm.max() / 2. # front color black or white according to the background color
        # for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        #     axes[2].text(j, i, format(cm[i, j], 'd'),
        #             horizontalalignment = 'center',
        #             color = 'white' if cm[i, j] > thresh else "black")
        # axes[2].set_ylabel('True label')
        # axes[2].set_xlabel('Predicted label')

        # save to fs
        fig.savefig(out_path)
        plt.close()
        self.log.i("generate report image end")

    def save(self, h5_path=None, tflite_path=None):
        src_h5_path = self.save_best_weights_path
        if h5_path:
            shutil.copyfile(src_h5_path, h5_path)
        if tflite_path:
            print("save tfilte to :", tflite_path)
            import tensorflow as tf
            # converter = tf.lite.TFLiteConverter.from_keras_model(model)
            # tflite_model = converter.convert()
            # with open (tflite_path, "wb") as f:
            #     f.write(tflite_model)

            ## kpu V3 - nncase = 0.1.0rc5
            # model.save("weights.h5", include_optimizer=False)
            model = tf.keras.models.load_model(src_h5_path)
            tf.compat.v1.disable_eager_execution()
            converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
                src_h5_path,
                output_arrays=[
                    '{}/BiasAdd'.format(model.get_layer(None, -2).name)
                ])
            tfmodel = converter.convert()
            with open(tflite_path, "wb") as f:
                f.write(tfmodel)
        # if h5_path:
        #     self.log.i("save model as .h5 file")
        #     if not h5_path.endswith(".h5"):
        #         if os.path.isdir(h5_path):
        #             h5_path = os.path.join(h5_path, "classifier.h5")
        #         else:
        #             h5_path += ".h5"
        #     if not self.model:
        #         raise Exception("no model defined")
        #     self.model.save(h5_path)
        # if tflite_path:
        #     self.log.i("save model as .tflite file")
        #     if not tflite_path.endswith(".tflite"):
        #         if os.path.isdir(tflite_path):
        #             tflite_path = os.path.join(tflite_path, "classifier.tflite")
        #         else:
        #             tflite_path += ".tflite"
        #     import tensorflow as tf
        #     converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
        #     tflite_model = converter.convert()
        #     with open (tflite_path, "wb") as f:
        #         f.write(tflite_model)

    def infer(self, input):
        pass

    def get_sample_images(self, sample_num, copy_to_dir):
        from PIL import Image
        if self.datasets_img is None:
            raise Exception("datasets not loaded")
        indexes = np.random.choice(range(self.datasets_img.shape[0]),
                                   sample_num,
                                   replace=False)
        for i in indexes:
            img = self.datasets_img[i]
            path = os.path.join(copy_to_dir, f"image_{i}.jpg")
            img = Image.fromarray(img)
            img.save(path)
        # num_gen = self._get_sample_num(len(self.labels), sample_num)
        # for label in self.labels:
        #     num = num_gen.__next__()
        #     images = os.listdir(os.path.join(self.datasets_dir, label))
        #     images = random.sample(images, num)
        #     for image in images:
        #         shutil.copyfile(os.path.join(self.datasets_dir, label, image), os.path.join(copy_to_dir, image))

    def _get_confusion_matrix(self):
        batch_size = 5
        from tensorflow.keras.preprocessing.image import ImageDataGenerator
        from tensorflow.keras.applications.mobilenet import preprocess_input
        valid_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
        valid_data = valid_gen.flow_from_directory(
            self.datasets_list[0],  # flow_from_directory expects a directory path, not a list
            target_size=[self.input_shape[0], self.input_shape[1]],
            color_mode='rgb',
            batch_size=batch_size,
            class_mode='sparse',
            shuffle=False)
        prediction = self.model.predict_generator(valid_data,
                                                  steps=valid_data.samples //
                                                  batch_size,
                                                  verbose=1)
        predict_labels = np.argmax(prediction, axis=1)
        true_labels = valid_data.classes
        if len(predict_labels) != len(true_labels):
            true_labels = true_labels[0:len(predict_labels)]
        cm = confusion_matrix(true_labels, predict_labels)
        return cm, valid_data.class_indices

    def _unpack_datasets(self,
                         datasets_zip,
                         datasets_dir=None,
                         rm_dataset=True):
        '''
            unpack zip datasets to /temp; mounting /temp as tmpfs is recommended
            zip layout should be:
                            datasets
                                   |
                                    ---- tfrecord1
                                   |
                                    ---- tfrecord2
            or:
                        ---- tfrecord1
                        ---- tfrecord2
        '''
        if not datasets_dir:
            datasets_dir = os.path.join(tempfile.gettempdir(),
                                        "detector_datasets")
            if rm_dataset:
                self.datasets_rm_dir = datasets_dir
                self.need_rm_datasets = True
        if not os.path.exists(datasets_dir):
            os.makedirs(datasets_dir)
        zip_file = zipfile.ZipFile(datasets_zip)
        for names in zip_file.namelist():
            zip_file.extract(names, datasets_dir)
        zip_file.close()
        dirs = []
        for d in os.listdir(datasets_dir):
            if d.startswith(".") or not os.path.isdir(
                    os.path.join(datasets_dir, d)):
                continue
            dirs.append(d)
        if len(dirs) == 1:  # single sub dir, use it as the dataset root
            root_dir = dirs[0]
            datasets_dir = os.path.join(datasets_dir, root_dir)
        elif len(dirs) == 0:  # no sub dir, files live at the top level
            pass
        else:  # multiple folders, not supported
            return None
        return datasets_dir

    def _check_update_input_shape(self, img_shape):
        '''
            changes self.input_shape to img_shape if that shape is supported
        '''
        if img_shape not in self.support_shapes:
            return False
        self.input_shape = img_shape
        self.log.i(f"input_shape: {self.input_shape}")
        return True

    def _load_datasets(self, datasets_dir):
        '''
            load datasets, supported formats:
                TFRecord: tfrecord files and tf_label_map.pbtxt in datasets_dir
                Pascal VOC: images/, xml/ and labels.txt in datasets_dir
            @return ok, msg, labels, classes_data_counts, datasets_x, datasets_y
                    classes_data_counts: per-class sample count, a list indexed the same as labels
                    datasets_x: np.ndarray images, not normalized, RGB channel values: [0, 255]
                    datasets_y: bboxes and label indexes per image, format: [[xmin, ymin, xmax, ymax, label_index], ]
                                value range: [0, w], [0, h], not [0, 1]
            @attention self.input_shape may be modified by this function according to the datasets
        '''
        def is_tfrecord():
            label_file_name = "tf_label_map.pbtxt"
            label_file_path = os.path.join(datasets_dir, label_file_name)
            return os.path.exists(label_file_path)

        def is_pascal_voc():
            dirs = os.listdir(datasets_dir)
            return "images" in dirs and "xml" in dirs and "labels.txt" in dirs

        # detect datasets type
        # tfrecord
        if is_tfrecord():
            return self._load_datasets_tfrecord(datasets_dir)
        elif is_pascal_voc():
            return self._load_datasets_pascal_voc(datasets_dir)
        return False, "datasets error, not support format, please check", [], None, None, None

    def _load_datasets_tfrecord(self, datasets_dir):
        '''
            load TFRecord datasets; params and return values are the same as _load_datasets
        '''
        def decode_img(img_bytes):
            img = None
            msg = ""
            try:
                # TODO: remove this condition if vott fixed this issue: https://github.com/microsoft/VoTT/issues/1012
                if b'image/encoded' in img_bytes:
                    img_bytes = img_bytes[42:]
                # TODO: check image sha256
                img = self.tf.io.decode_jpeg(img_bytes).numpy()
            except Exception as e:
                msg = "decode image {} error: {}".format(file_name, e)
                self.on_warning_message(msg)
            return img, msg

        labels = []
        datasets_x = []
        datasets_y = []
        # tfrecord
        # tf_label_map.pbtxt file
        label_file_name = "tf_label_map.pbtxt"
        label_file_path = os.path.join(datasets_dir, label_file_name)
        if not os.path.exists(label_file_path):
            return False, f"no file {label_file_name} exists", [], None, None, None
        try:
            labels = self._decode_pbtxt_file(label_file_path)
            self.log.i(f"labels: {labels}")
        except Exception as e:
            return False, str(e), [], None, None, None
        # check labels
        ok, msg = self._is_labels_valid(labels)
        if not ok:
            return False, msg, [], None, None, None
        labels_len = len(labels)
        if labels_len < 1:
            return False, 'no classes found', [], None, None, None
        if labels_len > self.config_max_classes_limit:
            return False, 'too many classes, limit:{}, datasets:{}'.format(
                self.config_max_classes_limit,
                len(labels)), [], None, None, None

        # *.tfrecord file
        tfrecord_files = []
        classes_data_counts = [0] * labels_len
        for name in os.listdir(datasets_dir):
            path = os.path.join(datasets_dir, name)
            if (name.startswith(".") or name == "__pycache__"
                    or os.path.isdir(path) or not path.endswith(".tfrecord")):
                continue
            tfrecord_files.append(path)
        # parse tfrecord file
        self.log.i("detect {} tfrecord files".format(len(tfrecord_files)))
        raws = self.tf.data.TFRecordDataset(tfrecord_files)
        # for raw in raws:
        #     example = self.tf.train.Example()
        #     example.ParseFromString(raw.numpy())
        #     print(example)
        feature_description = {
            "image/encoded": self.tf.io.FixedLenFeature([], self.tf.string),
            "image/filename": self.tf.io.FixedLenFeature([], self.tf.string),
            # "image/format": self.tf.io.FixedLenFeature([], self.tf.string),
            "image/width": self.tf.io.FixedLenFeature([], self.tf.int64),
            "image/height": self.tf.io.FixedLenFeature([], self.tf.int64),
            "image/object/class/label":
            self.tf.io.VarLenFeature(self.tf.int64),
            "image/object/class/text":
            self.tf.io.VarLenFeature(self.tf.string),
            "image/object/bbox/xmin":
            self.tf.io.VarLenFeature(self.tf.float32),
            "image/object/bbox/ymin":
            self.tf.io.VarLenFeature(self.tf.float32),
            "image/object/bbox/xmax":
            self.tf.io.VarLenFeature(self.tf.float32),
            "image/object/bbox/ymax":
            self.tf.io.VarLenFeature(self.tf.float32),
        }

        def _parse_func(example_proto):
            # Parse the input tf.Example proto using the dictionary above.
            return self.tf.io.parse_single_example(example_proto,
                                                   feature_description)

        parsed_dataset = raws.map(_parse_func)
        input_shape_checked = False
        for record in parsed_dataset:
            # print(record["image/width"].numpy())
            # print(record["image/object/class/label"].values)
            # print(record["image/object/bbox/xmin"].values)
            # print(record['image/filename'])
            # print(record['image/encoded'])
            file_name = record['image/filename'].numpy().decode()
            img_shape = (record["image/height"].numpy(),
                         record["image/width"].numpy())
            y_labels = record["image/object/class/label"].values
            y_labels_txt = record["image/object/class/text"].values
            # bbox values come in range [0, 1]; scale to pixel range [0, w] / [0, h],
            # keeping float32 dtype, no need to convert to int
            y_bboxes_xmin = record["image/object/bbox/xmin"].values * img_shape[1]
            y_bboxes_ymin = record["image/object/bbox/ymin"].values * img_shape[0]
            y_bboxes_xmax = record["image/object/bbox/xmax"].values * img_shape[1]
            y_bboxes_ymax = record["image/object/bbox/ymax"].values * img_shape[0]

            shape_valid = True
            if not input_shape_checked:
                img, msg = decode_img(record['image/encoded'].numpy())
                if img is None:
                    continue
                if not self._check_update_input_shape(
                        img.shape) and not self.allow_reshape:
                    return False, "not supported input size: {}, supported: {}".format(
                        img.shape, self.support_shapes), [], None, None, None
                input_shape_checked = True
            # check image shape
            if img_shape != self.input_shape[:2]:
                shape_valid = False
                msg = "image {} shape not valid, input:{}, require:{}".format(
                    file_name, img_shape, self.input_shape)
                self.on_warning_message(msg)
                if not self.allow_reshape:
                    # not allow reshape, drop this image
                    continue
            # collect bboxes
            y_bboxes = []
            for i in range(len(y_labels)):
                # check label in labels
                label_txt = y_labels_txt[i].numpy().decode()
                if (label_txt not in labels) or \
                        (labels.index(label_txt) != y_labels[i].numpy()):  # label text must exist in labels with a matching index
                    msg = "image {}'s label error: label {}:{} mismatch, the pbtxt file may be wrong when using TFRecord".format(
                        file_name, y_labels[i].numpy(), label_txt)
                    self.on_warning_message(msg)
                    continue
                y_bboxes.append([
                    y_bboxes_xmin[i].numpy(), y_bboxes_ymin[i].numpy(),
                    y_bboxes_xmax[i].numpy(), y_bboxes_ymax[i].numpy(),
                    y_labels[i].numpy()
                ])
                classes_data_counts[y_labels[i].numpy()] += 1
            # no bbox, next
            if len(y_bboxes) < 1:
                continue
            # image decode
            img, msg = decode_img(record['image/encoded'].numpy())
            if img is None:
                continue
            # check image shape again
            if img.shape != self.input_shape:
                if shape_valid:  # only warn once
                    msg = "image {} shape not valid, input:{}, require:{}".format(
                        file_name, img.shape, self.input_shape)
                    self.on_warning_message(msg)
                if not self.allow_reshape:
                    # not allow reshape, drop this image
                    continue
                # resize the image to the required shape and rescale its bboxes to match
                origin_shape = img.shape
                img = self.tf.image.resize(img, self.input_shape[:2],
                                           method="nearest").numpy()
                y_bboxes = self._reshape_bbox(origin_shape, self.input_shape,
                                              y_bboxes)
            datasets_x.append(img)
            datasets_y.append(y_bboxes)
        return True, "ok", labels, classes_data_counts, datasets_x, datasets_y

    def _load_datasets_pascal_voc(self, datasets_dir):
        '''
            load Pascal VOC datasets; params and return values are the same as _load_datasets
        '''
        from parse_pascal_voc_xml import decode_pascal_voc_xml
        from PIL import Image
        labels = []
        datasets_x = []
        datasets_y = []

        img_dir = os.path.join(datasets_dir, "images")
        ann_dir = os.path.join(datasets_dir, "xml")
        labels_path = os.path.join(datasets_dir, "labels.txt")

        # get labels from labels.txt
        with open(labels_path) as f:
            labels = f.read().split()
        # check labels
        ok, msg = self._is_labels_valid(labels)
        if not ok:
            return False, msg, [], None, None, None
        labels_len = len(labels)
        if labels_len < 1:
            return False, 'no classes found', [], None, None, None
        if labels_len > self.config_max_classes_limit:
            return False, 'too many classes, limit:{}, datasets:{}'.format(
                self.config_max_classes_limit,
                len(labels)), [], None, None, None
        classes_data_counts = [0] * labels_len
        # get xml path
        xmls = []
        for name in os.listdir(ann_dir):
            # print("--", name)
            if name.endswith(".xml"):
                xmls.append(os.path.join(ann_dir, name))
                continue
            if os.path.isdir(os.path.join(ann_dir, name)):
                for sub_name in os.listdir(os.path.join(ann_dir, name)):
                    if sub_name.endswith(".xml"):
                        path = os.path.join(ann_dir, name, sub_name)
                        xmls.append(path)
        # decode xml
        input_shape_checked = False
        for xml_path in xmls:
            with open(xml_path) as f:
                xml = f.read()
                ok, result = decode_pascal_voc_xml(xml)
                if not ok:
                    result = f"decode xml {xml_path} fail, reason: {result}"
                    self.on_warning_message(result)
                    continue
                # shape
                img_shape = (result['height'], result['width'],
                             result['depth'])
                # check the first image's shape and switch to a supported input_shape if needed
                if not input_shape_checked:
                    if not self._check_update_input_shape(
                            img_shape) and not self.allow_reshape:
                        return False, "not supported input size, supported: {}".format(
                            self.support_shapes), [], None, None, None
                    input_shape_checked = True

                need_to_reshape = False
                if img_shape != self.input_shape:
                    msg = f"decode xml {xml_path} ok, but shape {img_shape} not the same as expected: {self.input_shape}"
                    if not self.allow_reshape:
                        self.on_warning_message(msg)
                        continue
                    else:
                        need_to_reshape = True

                # load image: look for images/<xml base name>.jpg or .png
                _, name = os.path.split(xml_path)
                name, ext = os.path.splitext(name)
                img_path = os.path.join(img_dir, name)
                found = False
                if os.path.exists(img_path + ".jpg"):
                    img_path = img_path + ".jpg"
                    found = True
                elif os.path.exists(img_path + ".png"):
                    img_path = img_path + ".png"
                    found = True

                if found:
                    if not need_to_reshape:
                        img = np.array(Image.open(img_path), dtype='uint8')
                    else:
                        img = np.array(Image.open(img_path).resize(
                            [self.input_shape[1], self.input_shape[0]],
                            Image.NEAREST),
                                       dtype='uint8')
                else:
                    # images/tututututut.jpg
                    img_path = os.path.join(img_dir, result['filename'])
                    if os.path.exists(img_path):
                        if not need_to_reshape:
                            img = np.array(Image.open(img_path), dtype='uint8')
                        else:
                            # Image.resize expects (width, height)
                            img = np.array(Image.open(img_path).resize(
                                [self.input_shape[1], self.input_shape[0]],
                                Image.NEAREST),
                                           dtype='uint8')
                    else:
                        result = f"decode xml {xml_path}, can not find iamge: {result['path']}"
                        self.on_warning_message(result)
                        continue
                # load bndboxes
                y = []
                for bbox in result['bboxes']:
                    if bbox[4] not in labels:
                        result = f"decode xml {xml_path}, label '{bbox[4]}' not in labels.txt"
                        self.on_warning_message(result)
                        continue
                    label_idx = labels.index(bbox[4])
                    bbox[4] = label_idx  # replace label text with label index
                    classes_data_counts[label_idx] += 1
                    # keep pixel-range coords: [xmin, ymin, xmax, ymax, label_index]
                    y.append(bbox[:5])
                if len(y) < 1:
                    result = f"decode xml {xml_path}, no object, skip"
                    self.on_warning_message(result)
                    continue
                if need_to_reshape:
                    y = self._reshape_bbox(img_shape, self.input_shape, y)
                datasets_x.append(img)
                datasets_y.append(y)
        return True, "ok", labels, classes_data_counts, datasets_x, datasets_y

    def _decode_pbtxt_file(self, file_path):
        '''
            @return list of label names; raises Exception on error
        '''
        res = []
        with open(file_path) as f:
            content = f.read()
            items = re.findall("id: ([0-9].?)\n.*name: '(.*)'", content,
                               re.MULTILINE)
            for i, item in enumerate(items):
                id = int(item[0])
                name = item[1]
                if i != id - 1:
                    raise Exception(
                        f"datasets pbtxt file error, label:{name}'s id should be {i+1}, but now {id}, don't manually edit pbtxt file"
                    )
                res.append(name)
        return res

    def on_warning_message(self, msg):
        self.log.w(msg)
        self.warning_msg.append(msg)

    def _is_labels_valid(self, labels):
        '''
            len(labels) should be >= 1,
            and labels should contain only ASCII letters, no Chinese or special characters
        '''
        if len(labels) < 1:
            err_msg = "labels error: datasets have no class"
            return False, err_msg
        if len(labels) > self.config_max_classes_limit:
            err_msg = "labels error: too many classes, now {}, but only {} supported".format(
                len(labels), self.config_max_classes_limit)
            return False, err_msg
        for label in labels:
            if not isascii(label):
                return False, "labels error: class name(label) should not contain special letters"
        return True, "ok"

    def _is_datasets_valid(self,
                           labels,
                           classes_dataset_count,
                           one_class_min_images_num=100,
                           one_class_max_images_num=2000):
        '''
            the sample count of every class should be >= one_class_min_images_num and <= one_class_max_images_num
        '''
        for i, label in enumerate(labels):
            # check image number
            if classes_dataset_count[i] < one_class_min_images_num:
                return False, "not enough train images in class '{}': only {}, need >= {}; total dataset count: {}".format(
                    label, classes_dataset_count[i], one_class_min_images_num,
                    sum(classes_dataset_count))
            if classes_dataset_count[i] > one_class_max_images_num:
                return False, "too many train images in class '{}': {}, need <= {}; total dataset count: {}".format(
                    label, classes_dataset_count[i], one_class_max_images_num,
                    sum(classes_dataset_count))
        return True, "ok"

    def _reshape_bbox(self, origin_shape, to_shape, bboxes):
        new_bboxes = []
        # print(origin_shape)
        for bbox in bboxes:
            new_bbox = [
                bbox[0] * to_shape[1] / origin_shape[1],
                bbox[1] * to_shape[0] / origin_shape[0],
                bbox[2] * to_shape[1] / origin_shape[1],
                bbox[3] * to_shape[0] / origin_shape[0], bbox[4]
            ]
            new_bboxes.append(new_bbox)
        # print(new_bboxes)
        return new_bboxes
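A hypothetical usage sketch for the Detector class above. The dataset and output paths are made-up examples, and the class depends on project-local modules (yolo.frontend, kmeans, parse_pascal_voc_xml, Fake_Logger), so this only illustrates the intended call sequence rather than a guaranteed-runnable script.

# datasets.json is assumed to contain {"train": [...dirs...], "val": [...dirs...]}
detector = Detector(input_shape=(224, 224, 3),
                    datasets_dir="datasets.json")  # hypothetical path
detector.train(epochs=100,
               batch_size=5,
               progress_cb=lambda percent, msg: print(percent, msg),
               save_best_weights_path="out/best_weights.h5",
               save_final_weights_path="out/final_weights.h5")
detector.report("out/report.jpg")
detector.save(h5_path="out/detector.h5", tflite_path="out/detector.tflite")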
Example #3
    def __init__(self,
                 input_shape=(224, 224, 3),
                 datasets_dir=None,
                 datasets_zip=None,
                 unpack_dir=None,
                 logger=None,
                 max_classes_num=15,
                 min_images_num=40,
                 max_images_num=2000,
                 allow_reshape=False):
        '''
            input_shape: input shape (height, width, channels)
            min_images_num: minimum number of images per class
        '''
        # import_libs()  # for multiprocessing
        import tensorflow as tf
        self.input_shape = input_shape
        self.need_rm_datasets = False
        self.datasets_rm_dir = None
        self.model = None
        self.history = None
        self.warning_msg = []  # append warning message here
        if logger:
            self.log = logger
        else:
            self.log = Fake_Logger()
        # unzip datasets
        if datasets_zip:
            self.datasets_dir = self._unpack_datasets(datasets_zip, unpack_dir)
            if not self.datasets_dir:
                self.log.e("can't detect datasets, check zip format")
                raise Exception("can't detect datasets, check zip format")
        elif datasets_dir:
            self.datasets_dir = datasets_dir
        else:
            self.log.e("no datasets args")
            raise Exception("no datasets args")
        # get labels by directory name
        self.labels = self._get_labels(self.datasets_dir)
        # check label
        ok, err_msg = self._is_label_data_valid(
            self.labels,
            max_classes_num=max_classes_num,
            min_images_num=min_images_num,
            max_images_num=max_images_num)
        if not ok:
            self.log.e(err_msg)
            raise Exception(err_msg)
        # check datasets format
        ok, err_msg = self._is_datasets_shape_valid(self.datasets_dir,
                                                    self.input_shape)
        if not ok:
            if not allow_reshape:
                self.log.e(err_msg)
                raise Exception(err_msg)
            self.on_warning_message(err_msg)

        class _Train_progress_cb(tf.keras.callbacks.Callback):  # training progress / remaining-time callback
            def __init__(self, epochs, user_progress_callback, logger):
                self.epochs = epochs
                self.logger = logger
                self.user_progress_callback = user_progress_callback

            def on_epoch_begin(self, epoch, logs=None):
                self.logger.i("epoch {} start".format(epoch))

            def on_epoch_end(self, epoch, logs=None):
                self.logger.i("epoch {} end: {}".format(epoch, logs))
                if self.user_progress_callback:
                    self.user_progress_callback(
                        (epoch + 1) / self.epochs * 100, "train epoch end")

            def on_train_begin(self, logs=None):
                self.logger.i("train start")
                if self.user_progress_callback:
                    self.user_progress_callback(0, "train start")

            def on_train_end(self, logs=None):
                self.logger.i("train end")
                if self.user_progress_callback:
                    self.user_progress_callback(100, "train end")

        self.Train_progress_cb = _Train_progress_cb
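This classifier-style constructor derives labels from the sub-directory names of the dataset folder. A sketch of the expected class-per-folder layout follows; the class names and file names are made up for illustration.

import os

datasets_dir = "datasets"
for label in ("cat", "dog"):  # hypothetical class names
    os.makedirs(os.path.join(datasets_dir, label), exist_ok=True)

# datasets/
#     cat/    img_0001.jpg, img_0002.jpg, ...
#     dog/    img_0001.jpg, img_0002.jpg, ...
# Each class folder should hold at least min_images_num and at most
# max_images_num images, or _is_label_data_valid will reject the datasets.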
Example #4
    def __init__(self,
                 input_shape=(224, 224, 3),
                 datasets_dir=None,
                 datasets_zip=None,
                 unpack_dir=None,
                 logger=None,
                 max_classes_limit=15,
                 one_class_min_images_num=100,
                 one_class_max_images_num=2000,
                 allow_reshape=False,
                 support_shapes=((224, 224, 3), (240, 240, 3))):
        '''
            input_shape: input shape (height, width, channels)
            one_class_min_images_num: minimum number of images per class
        '''
        import tensorflow as tf  # imported here (not at module level) to work with multiprocessing
        self.tf = tf
        self.need_rm_datasets = False
        self.input_shape = input_shape
        self.support_shapes = support_shapes
        if self.input_shape not in self.support_shapes:
            raise Exception(
                "input shape {} not supported, supported shapes: {}".format(
                    self.input_shape, self.support_shapes))
        self.allow_reshape = allow_reshape  # if a dataset image's shape differs from the required shape, reshape it
        self.config_max_classes_limit = max_classes_limit
        self.config_one_class_min_images_num = one_class_min_images_num
        self.config_one_class_max_images_num = one_class_max_images_num
        self.datasets_rm_dir = None
        self.model = None
        self.history = None
        self.warning_msg = []  # append warning message here
        if logger:
            self.log = logger
        else:
            self.log = Fake_Logger()
        # unzip datasets
        if datasets_zip:
            self.datasets_dir = self._unpack_datasets(datasets_zip, unpack_dir)
            if not self.datasets_dir:
                self.log.e("can't detect datasets, check zip format")
                raise Exception("can't detect datasets, check zip format")
        elif datasets_dir:
            self.datasets_dir = datasets_dir
        else:
            self.log.e("no datasets args")
            raise Exception("no datasets args")
        # parse datasets
        ok, msg, self.labels, classes_data_counts, datasets_x, datasets_y = self._load_datasets(
            self.datasets_dir)
        if not ok:
            msg = f"datasets format error: {msg}"
            self.log.e(msg)
            raise Exception(msg)
        # check datasets
        ok, err_msg = self._is_datasets_valid(
            self.labels,
            classes_data_counts,
            one_class_min_images_num=self.config_one_class_min_images_num,
            one_class_max_images_num=self.config_one_class_max_images_num)
        if not ok:
            self.log.e(err_msg)
            raise Exception(err_msg)
        self.log.i(
            "datasets loaded, check passed, images num:{}, bboxes num:{}".
            format(len(datasets_x), sum(classes_data_counts)))
        self.datasets_x = np.array(datasets_x, dtype='uint8')
        self.datasets_y = datasets_y

        class _Train_progress_cb(tf.keras.callbacks.Callback):  # training progress / remaining-time callback
            def __init__(self, epochs, user_progress_callback, logger):
                self.epochs = epochs
                self.logger = logger
                self.user_progress_callback = user_progress_callback

            def on_epoch_begin(self, epoch, logs=None):
                self.logger.i("epoch {} start".format(epoch))

            def on_epoch_end(self, epoch, logs=None):
                self.logger.i("epoch {} end: {}".format(epoch, logs))
                if self.user_progress_callback:
                    self.user_progress_callback(
                        (epoch + 1) / self.epochs * 100, "train epoch end")

            def on_train_begin(self, logs=None):
                self.logger.i("train start")
                if self.user_progress_callback:
                    self.user_progress_callback(0, "train start")

            def on_train_end(self, logs=None):
                self.logger.i("train end")
                if self.user_progress_callback:
                    self.user_progress_callback(100, "train end")

        self.Train_progress_cb = _Train_progress_cb
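Example #4 validates the parsed datasets with _is_datasets_valid, which enforces a per-class sample count window. A standalone sketch of that rule, with made-up labels and counts for illustration:

def check_class_counts(labels, counts, min_num=100, max_num=2000):
    # every class must contribute between min_num and max_num samples
    for label, count in zip(labels, counts):
        if count < min_num:
            return False, f"class '{label}' has {count} samples, needs >= {min_num}"
        if count > max_num:
            return False, f"class '{label}' has {count} samples, needs <= {max_num}"
    return True, "ok"

print(check_class_counts(["cat", "dog"], [150, 30]))
# (False, "class 'dog' has 30 samples, needs >= 100")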
Example #5
class Classifier(Train_Base):
    def __init__(self,
                 input_shape=(224, 224, 3),
                 datasets_dir=None,
                 datasets_zip=None,
                 unpack_dir=None,
                 logger=None,
                 max_classes_num=15,
                 min_images_num=40,
                 max_images_num=2000,
                 allow_reshape=False):
        '''
            input_shape: input shape (height, width, channels)
            min_images_num: minimum number of images per class
        '''
        # import_libs()  # for multiprocessing
        import tensorflow as tf
        self.input_shape = input_shape
        self.need_rm_datasets = False
        self.datasets_rm_dir = None
        self.model = None
        self.history = None
        self.warning_msg = []  # append warning message here
        if logger:
            self.log = logger
        else:
            self.log = Fake_Logger()
        # unzip datasets
        if datasets_zip:
            self.datasets_dir = self._unpack_datasets(datasets_zip, unpack_dir)
            if not self.datasets_dir:
                self.log.e("can't detect datasets, check zip format")
                raise Exception("can't detect datasets, check zip format")
        elif datasets_dir:
            self.datasets_dir = datasets_dir
        else:
            self.log.e("no datasets args")
            raise Exception("no datasets args")
        # get labels by directory name
        self.labels = self._get_labels(self.datasets_dir)
        # check label
        ok, err_msg = self._is_label_data_valid(
            self.labels,
            max_classes_num=max_classes_num,
            min_images_num=min_images_num,
            max_images_num=max_images_num)
        if not ok:
            self.log.e(err_msg)
            raise Exception(err_msg)
        # check datasets format
        ok, err_msg = self._is_datasets_shape_valid(self.datasets_dir,
                                                    self.input_shape)
        if not ok:
            if not allow_reshape:
                self.log.e(err_msg)
                raise Exception(err_msg)
            self.on_warning_message(err_msg)

        class _Train_progress_cb(tf.keras.callbacks.Callback):  # training progress / remaining-time callback
            def __init__(self, epochs, user_progress_callback, logger):
                self.epochs = epochs
                self.logger = logger
                self.user_progress_callback = user_progress_callback

            def on_epoch_begin(self, epoch, logs=None):
                self.logger.i("epoch {} start".format(epoch))

            def on_epoch_end(self, epoch, logs=None):
                self.logger.i("epoch {} end: {}".format(epoch, logs))
                if self.user_progress_callback:
                    self.user_progress_callback(
                        (epoch + 1) / self.epochs * 100, "train epoch end")

            def on_train_begin(self, logs=None):
                self.logger.i("train start")
                if self.user_progress_callback:
                    self.user_progress_callback(0, "train start")

            def on_train_end(self, logs=None):
                self.logger.i("train end")
                if self.user_progress_callback:
                    self.user_progress_callback(100, "train end")

        self.Train_progress_cb = _Train_progress_cb

    def __del__(self):
        if self.need_rm_datasets:
            try:
                shutil.rmtree(self.datasets_dir)
            except Exception as e:
                try:
                    self.log.e("clean temp files error:{}".format(e))
                except Exception:
                    print("log object invalid")

    def train(self,
              epochs=100,
              progress_cb=None,
              weights=os.path.join(curr_file_dir, "weights",
                                   "mobilenet_7_5_224_tf_no_top.h5"),
              batch_size=5):
        self.log.i("train, labels:{}".format(self.labels))
        self.log.d("train, datasets dir:{}".format(self.datasets_dir))

        from mobilenet_sipeed import mobilenet
        import tensorflow as tf

        # pooling='avg'; use symmetric padding instead of bottom/right-only padding for the K210
        base_model = mobilenet.MobileNet0(input_shape=self.input_shape,
                                          alpha=0.75,
                                          depth_multiplier=1,
                                          dropout=0.001,
                                          pooling='avg',
                                          weights=weights,
                                          include_top=False)
        # update top layer
        out = base_model.output
        out = tf.keras.layers.Dropout(0.001, name='dropout')(out)
        preds = tf.keras.layers.Dense(len(self.labels),
                                      activation='softmax')(out)
        self.model = tf.keras.models.Model(inputs=base_model.input,
                                           outputs=preds)
        # freeze the pretrained backbone and train only the new top layers
        for layer in self.model.layers[:86]:
            layer.trainable = False
        for layer in self.model.layers[86:]:
            layer.trainable = True
        # (alternatives: categorical_crossentropy with Adam or RMSProp also work)
        self.model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
                           loss='sparse_categorical_crossentropy',
                           metrics=['accuracy'])
        # print model summary
        self.model.summary()

        # train
        # datasets process
        from tensorflow.keras.preprocessing.image import ImageDataGenerator
        from tensorflow.keras.applications.mobilenet import preprocess_input
        train_gen = ImageDataGenerator(preprocessing_function=preprocess_input,
                                       rotation_range=180,
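                                       # note: featurewise_center and
                                       # featurewise_std_normalization take
                                       # effect only after train_gen.fit() is
                                       # called on sample images; otherwise
                                       # Keras warns and skips them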
                                       featurewise_center=True,
                                       featurewise_std_normalization=True,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       zoom_range=0.5,
                                       shear_range=0.5,
                                       validation_split=0.2)

        train_data = train_gen.flow_from_directory(
            self.datasets_dir,
            target_size=(self.input_shape[0], self.input_shape[1]),
            color_mode='rgb',
            batch_size=batch_size,
            class_mode='sparse',  # None / sparse / binary / categorical
            shuffle=True,
            subset="training")
        valid_data = train_gen.flow_from_directory(
            self.datasets_dir,
            target_size=(self.input_shape[0], self.input_shape[1]),
            color_mode='rgb',
            batch_size=batch_size,
            class_mode='sparse',
            shuffle=False,
            subset="validation")
        self.log.i("train data:{}, valid data:{}".format(
            train_data.samples, valid_data.samples))
        callbacks = [self.Train_progress_cb(epochs, progress_cb, self.log)]
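        # note: fit_generator is deprecated in TF >= 2.1 (model.fit accepts
        # generators directly there); kept as-is for the TF version this
        # snippet targets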
        self.history = self.model.fit_generator(
            train_data,
            validation_data=valid_data,
            steps_per_epoch=train_data.samples // batch_size,
            validation_steps=valid_data.samples // batch_size,
            epochs=epochs,
            callbacks=callbacks)

    def report(self, out_path, limit_y_range=None):
        '''
            generate result charts
        '''
        self.log.i("generate report image")
        if not self.history:
            return
        history = self.history

        # use the non-GUI 'agg' backend so this also works on headless servers
        # without TkAgg support
        plt.switch_backend('agg')

        fig, axes = plt.subplots(3,
                                 1,
                                 constrained_layout=True,
                                 figsize=(10, 16),
                                 dpi=100)
        if limit_y_range:
            plt.ylim(limit_y_range)

        # accuracy curves; Keras < 2.3 names the history keys 'acc'/'val_acc',
        # newer versions use 'accuracy'/'val_accuracy', e.g.:
        # {'loss': [0.5860330664989357, 0.3398533443955177], 'accuracy': [0.70944744, 0.85026735], 'val_loss': [0.4948340670338699, 0.49342870752194096], 'val_accuracy': [0.7, 0.74285716]}
        if "acc" in history.history:
            kws = {
                "acc": "acc",
                "val_acc": "val_acc",
                "loss": "loss",
                "val_loss": "val_loss"
            }
        else:
            kws = {
                "acc": "accuracy",
                "val_acc": "val_accuracy",
                "loss": "loss",
                "val_loss": "val_loss"
            }
        axes[0].plot(history.history[kws['acc']],
                     color='#2886EA',
                     label="train")
        axes[0].plot(history.history[kws['val_acc']],
                     color='#3FCD6D',
                     label="valid")
        axes[0].set_title('model accuracy')
        axes[0].set_ylabel('accuracy')
        axes[0].set_xlabel('epoch')
        axes[0].locator_params(integer=True)
        axes[0].legend()

        # loss and val_loss
        axes[1].plot(history.history[kws['loss']],
                     color='#2886EA',
                     label="train")
        axes[1].plot(history.history[kws['val_loss']],
                     color='#3FCD6D',
                     label="valid")
        axes[1].set_title('model loss')
        axes[1].set_ylabel('loss')
        axes[1].set_xlabel('epoch')
        axes[1].locator_params(integer=True)
        axes[1].legend()

        # confusion matrix
        cm, labels_idx = self._get_confusion_matrix()
        axes[2].imshow(cm, interpolation='nearest', cmap=plt.cm.GnBu)
        axes[2].set_title("confusion matrix")
        num_local = np.array(range(len(labels_idx)))
        axes[2].set_xticks(num_local)
        axes[2].set_xticklabels(labels_idx.keys(), rotation=45)
        axes[2].set_yticks(num_local)
        axes[2].set_yticklabels(labels_idx.keys())

        # choose white or black text depending on the cell's background color
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            axes[2].text(j,
                         i,
                         format(cm[i, j], 'd'),
                         horizontalalignment='center',
                         color='white' if cm[i, j] > thresh else "black")
        axes[2].set_ylabel('True label')
        axes[2].set_xlabel('Predicted label')

        # save to fs
        fig.savefig(out_path)
        plt.close()
        self.log.i("generate report image end")

    def save(self, h5_path=None, tflite_path=None):
        if h5_path:
            self.log.i("save model as .h5 file")
            if not h5_path.endswith(".h5"):
                if os.path.isdir(h5_path):
                    h5_path = os.path.join(h5_path, "classifier.h5")
                else:
                    h5_path += ".h5"
            if not self.model:
                raise Exception("no model defined")
            self.model.save(h5_path)
        if tflite_path:
            self.log.i("save model as .tflite file")
            if not tflite_path.endswith(".tflite"):
                if os.path.isdir(tflite_path):
                    tflite_path = os.path.join(tflite_path,
                                               "classifier.tflite")
                else:
                    tflite_path += ".tflite"
            import tensorflow as tf
            converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
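            # optional tweak (an assumption, not in the original): setting
            # converter.optimizations = [tf.lite.Optimize.DEFAULT] before
            # convert() enables post-training quantization for a smaller file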
            tflite_model = converter.convert()
            with open(tflite_path, "wb") as f:
                f.write(tflite_model)

    def infer(self, input):
        # not implemented in this snippet
        pass

    def get_sample_images(self, sample_num, copy_to_dir):
        if not self.datasets_dir or not os.path.exists(self.datasets_dir):
            raise Exception("datasets dir not exists")
        num_gen = self._get_sample_num(len(self.labels), sample_num)
        for label in self.labels:
            num = next(num_gen)
            images = os.listdir(os.path.join(self.datasets_dir, label))
            images = random.sample(images, num)
            for image in images:
                shutil.copyfile(os.path.join(self.datasets_dir, label, image),
                                os.path.join(copy_to_dir, image))

    def _get_confusion_matrix(self):
        batch_size = 5
        from tensorflow.keras.preprocessing.image import ImageDataGenerator
        from tensorflow.keras.applications.mobilenet import preprocess_input
        valid_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
        valid_data = valid_gen.flow_from_directory(
            self.datasets_dir,
            target_size=[self.input_shape[0], self.input_shape[1]],
            color_mode='rgb',
            batch_size=batch_size,
            class_mode='sparse',
            shuffle=False)
        prediction = self.model.predict_generator(valid_data,
                                                  steps=valid_data.samples //
                                                  batch_size,
                                                  verbose=1)
        predict_labels = np.argmax(prediction, axis=1)
        true_labels = valid_data.classes
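        # the matrix covers the whole datasets directory (not only the
        # validation split); predict runs samples // batch_size full batches,
        # so the true labels are truncated to match the prediction count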
        if len(predict_labels) != len(true_labels):
            true_labels = true_labels[0:len(predict_labels)]
        cm = confusion_matrix(true_labels, predict_labels)
        return cm, valid_data.class_indices

    def _unpack_datasets(self,
                         datasets_zip,
                         datasets_dir=None,
                         rm_dataset=True):
        '''
            unpack the zip datasets into a temp directory (mounting it as
            tmpfs is recommended); the zip layout should be either:
                            datasets
                                   |
                                    ---- class1
                                   |
                                    ---- class2
            or:
                        ---- class1
                        ---- class2
        '''
        if not datasets_dir:
            datasets_dir = os.path.join(tempfile.gettempdir(),
                                        "classifier_datasets")
            if rm_dataset:
                self.datasets_rm_dir = datasets_dir
                self.need_rm_datasets = True
        if not os.path.exists(datasets_dir):
            os.makedirs(datasets_dir)
        # extract all members; the context manager closes the zip afterwards
        with zipfile.ZipFile(datasets_zip) as zip_file:
            zip_file.extractall(datasets_dir)
        dirs = []
        for d in os.listdir(datasets_dir):
            if d.startswith(".") or not os.path.isdir(
                    os.path.join(datasets_dir, d)):
                continue
            dirs.append(d)
        if len(dirs) == 1:  # the zip wraps all classes in one root folder
            root_dir = dirs[0]
            datasets_dir = os.path.join(datasets_dir, root_dir)
        elif len(dirs) > 1:  # class folders sit at the zip root, keep as-is
            pass
        else:  # empty zip
            return None
        return datasets_dir

    def _get_labels(self, datasets_dir):
        labels = []
        for d in os.listdir(datasets_dir):
            if d.startswith(".") or d == "__pycache__":
                continue
            if os.path.isdir(os.path.join(datasets_dir, d)):
                labels.append(d)
        return labels

    def _is_label_data_valid(self,
                             labels,
                             max_classes_num=15,
                             min_images_num=40,
                             max_images_num=2000):
        '''
            there should be at least 2 labels,
            label names should be ASCII only (no Chinese or special characters),
            and each class should hold between min_images_num and
            max_images_num images
        '''
        if len(labels) <= 1:
            err_msg = "not enough classes in datasets, or directory structure error"
            return False, err_msg
        if len(labels) > max_classes_num:
            err_msg = "too many classes in datasets, limit: {} classes".format(
                max_classes_num)
            return False, err_msg
        self.log.d("labels: {}".format(labels))
        for label in labels:
            if not isascii(label):
                return False, "class name(label) should not contain special letters"
            # check image number
            files = os.listdir(os.path.join(self.datasets_dir, label))
            if len(files) < min_images_num:
                return False, "not enough training images in one class, should be > {}".format(
                    min_images_num)
            if len(files) > max_images_num:
                return False, "too many training images in one class, should be < {}".format(
                    max_images_num)
        return True, ""

    def _is_datasets_shape_valid(self, datasets_dir, shape):
        from PIL import Image
        ok = True
        msg = ""
        num_gen = self._get_sample_num(len(self.labels), len(self.labels))
        for label in self.labels:
            num = next(num_gen)
            # use the datasets_dir parameter (the original used
            # self.datasets_dir, ignoring the argument passed in)
            images = os.listdir(os.path.join(datasets_dir, label))
            images = random.sample(images, num)
            for image in images:
                path = os.path.join(datasets_dir, label, image)
                img = np.array(Image.open(path))
                if img.shape != shape:
                    msg += f"image {label}/{image} shape is {img.shape}, but require {shape}\n"
                    ok = False
        return ok, msg

    def on_warning_message(self, msg):
        self.log.w(msg)
        self.warning_msg.append(msg)
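
    # a minimal usage sketch (the class name and file paths here are
    # illustrative assumptions, not taken from this snippet):
    #
    #   classifier = Classifier(datasets_zip="./datasets.zip")
    #   classifier.train(epochs=20, batch_size=16)
    #   classifier.report("./report.jpg")
    #   classifier.save(h5_path="./classifier.h5",
    #                   tflite_path="./classifier.tflite")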