Ejemplo n.º 1
0
def test_main(datasets_zip, model_path, report_path, log, use_cpu=False):
    """Run a short (2-epoch) detector training and dump the results.

    All artifacts (weights, tflite model, sample images) are written to the
    ``out`` directory, relative to the current working directory.

    Args:
        datasets_zip: Passed to ``Detector`` as ``datasets_zip``.
        model_path: Not used by this function; kept so the signature stays
            compatible with existing callers.
        report_path: Passed to ``detector.report``.
        log: Logger object; only its ``i`` and ``e`` methods are called here.
        use_cpu: When true, continue without a GPU instead of aborting.

    Returns:
        ``1`` when no GPU could be selected and ``use_cpu`` is false,
        otherwise ``None``.
    """
    import os

    curr_file_dir = os.path.abspath(os.path.dirname(__file__))
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs("out", exist_ok=True)

    try:
        gpu = gpu_utils.select_gpu(memory_require=1 * 1024 * 1024 * 1024,
                                   tf_gpu_mem_growth=False)
    except Exception as e:
        # Selection stays best-effort (fall back to "no GPU"), but the reason
        # is logged instead of being silently discarded.
        log.i("select GPU failed:", e)
        gpu = None

    if gpu is None:
        if not use_cpu:
            log.e("no free GPU")
            return 1
        log.i("no GPU, will use [CPU]")
    else:
        log.i("select", gpu)

    detector = Detector(input_shape=(224, 224, 3),
                        datasets_zip=datasets_zip,
                        logger=log,
                        one_class_min_images_num=2)
    # Pre-trained weights are looked up next to this source file.
    weights_path = os.path.abspath(os.path.join(
        curr_file_dir, "weights", "mobilenet_7_5_224_tf_no_top.h5"))
    detector.train(
        epochs=2,
        progress_cb=train_on_progress,
        weights=weights_path,
        save_best_weights_path="out/best_weights.h5",
        save_final_weights_path="out/final_weights.h5",
    )
    detector.report(report_path)
    detector.save(tflite_path="out/best_weights.tflite")
    detector.get_sample_images(5, "out/sample_images")

    print("--------result---------")
    print("anchors: {}".format(detector.anchors))
    print("labels:{}".format(detector.labels))
    print("-----------------------")
    if detector.warning_msg:
        print("---------------------")
        print("warning messages:")
        for msg in detector.warning_msg:
            print(msg)
        print("---------------------")
Ejemplo n.º 2
0
def test_main(datasets_zip, model_path, report_path, use_cpu=False):
    """Run a short (2-epoch) classifier training, then report and save it.

    A log file is written to ``out/train.log``, relative to the current
    working directory.

    Args:
        datasets_zip: Passed to ``Classifier`` as ``datasets_zip``.
        model_path: Passed to ``classifier.save``.
        report_path: Passed to ``classifier.report``.
        use_cpu: When true, continue without a GPU instead of aborting.

    Returns:
        ``1`` when no GPU could be selected and ``use_cpu`` is false,
        otherwise ``None``.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs("out", exist_ok=True)
    log = Logger(file_path="out/train.log")

    try:
        gpu = gpu_utils.select_gpu(memory_require=1 * 1024 * 1024 * 1024,
                                   tf_gpu_mem_growth=False)
    except Exception as e:
        # Selection stays best-effort (fall back to "no GPU"), but the reason
        # is logged instead of being silently discarded.
        log.i("select GPU failed:", e)
        gpu = None

    if gpu is None:
        if not use_cpu:
            log.e("no free GPU")
            return 1
        log.i("no GPU, will use [CPU]")
    else:
        log.i("select", gpu)

    classifier = Classifier(datasets_zip=datasets_zip, logger=log)
    classifier.train(epochs=2, progress_cb=train_on_progress)
    classifier.report(report_path)
    classifier.save(model_path)
Ejemplo n.º 3
0
    def detector_train(self, log):
        """Train a detector, export it as a kmodel and fill in the result files.

        Steps, in order: select a GPU (or fall back to CPU when
        ``config.allow_cpu`` is set), build and train a ``Detector``, write
        the report, export a tflite model plus sample images, convert the
        tflite model to a kmodel, copy the template files into
        ``self.result_dir``, and patch labels, anchors and the sensor window
        size into the copied boot script.

        Args:
            log: Logger object; only its ``i`` and ``e`` methods are called here.

        Returns:
            Tuple ``(detector, config.detector_result_file_name_prefix)``.

        Raises:
            Exception: Raised with a single ``(TrainFailReason, message)``
                tuple argument when no GPU is available and CPU training is
                not allowed (``ERROR_NODE_BUSY``), when the datasets are
                rejected by ``Detector`` (``ERROR_PARAM``), or when training
                or kmodel conversion fails (``ERROR_INTERNAL``).
        """
        # Check GPU availability and pick a usable one.
        try:
            gpu = gpu_utils.select_gpu(memory_require=config.detector_train_gpu_mem_require, tf_gpu_mem_growth=False)
        except Exception as e:
            # Selection stays best-effort (fall back to "no GPU"), but the
            # reason is logged instead of being silently discarded.
            log.i("select GPU failed:", e)
            gpu = None
        if gpu is None:
            if not config.allow_cpu:
                log.e("no free GPU")
                raise Exception((TrainFailReason.ERROR_NODE_BUSY, "node no enough GPU or GPU memory and not support CPU train"))
            log.i("no GPU, will use [CPU]")
        else:
            log.i("select", gpu)

        # Build the detector; a failure here is reported as a dataset problem.
        try:
            detector = Detector(input_shape=(224, 224, 3),
                                datasets_zip=self.datasets_zip_path,
                                datasets_dir=self.datasets_dir,
                                unpack_dir=self.temp_datasets_dir,
                                logger=log,
                                max_classes_limit=config.detector_train_max_classes_num,
                                one_class_min_images_num=config.detector_train_one_class_min_img_num,
                                one_class_max_images_num=config.detector_train_one_class_max_img_num,
                                allow_reshape=True)
        except Exception as e:
            log.e("train datasets not valid: {}".format(e))
            # Chain explicitly so the original cause stays attached.
            raise Exception((TrainFailReason.ERROR_PARAM, "datasets not valid: {}".format(str(e)))) from e

        # Start training.
        try:
            detector.train(epochs=config.detector_train_epochs,
                           progress_cb=self.__on_train_progress,
                           save_best_weights_path=self.best_h5_model_path,
                           save_final_weights_path=self.final_h5_model_path,
                           jitter=False,
                           is_only_detect=False,
                           batch_size=config.detector_train_batch_size,
                           train_times=5,
                           valid_times=2,
                           learning_rate=config.detector_train_learn_rate,
                           )
        except Exception as e:
            log.e("train error: {}".format(e))
            traceback.print_exc()
            raise Exception((TrainFailReason.ERROR_INTERNAL, "error occurred when train, error: {}".format(str(e)))) from e

        # Training finished; generate the report.
        log.i("train ok, now generate report")
        detector.report(self.result_report_img_path)

        # Generate the kmodel (the sample images are handed to the converter).
        log.i("now generate kmodel")
        detector.save(tflite_path=self.tflite_path)
        detector.get_sample_images(config.sample_image_num, self.dataset_sample_images_path)
        ok, msg = self.convert_to_kmodel(self.tflite_path, self.result_kmodel_path, config.ncc_kmodel_v3, self.dataset_sample_images_path)
        if not ok:
            log.e("convert to kmodel fail")
            raise Exception((TrainFailReason.ERROR_INTERNAL, "convert kmodel fail: {}".format(msg)))

        # Copy the template files.
        log.i("copy template files")
        template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "detector", "template")
        self.__copy_template_files(template_dir, self.result_dir)

        # Build every replacement line before touching the output files, so a
        # failure while formatting cannot leave a truncated boot script behind.
        labels_line = 'labels = ["{}"]'.format('", "'.join(detector.labels))
        anchors_line = 'anchors = [{}]'.format(', '.join(str(i) for i in detector.anchors))
        windowing_line = 'sensor.set_windowing(({}, {}))'.format(detector.input_shape[1], detector.input_shape[0])

        # Write the label file. UTF-8 is explicit so the output does not
        # depend on the platform's default encoding.
        with open(self.result_labels_path, "w", encoding="utf-8") as f:
            f.write(labels_line)

        # Patch the placeholders in the copied boot script.
        with open(self.result_boot_py_path, encoding="utf-8") as f:
            boot_py = f.read()
        boot_py = boot_py.replace('labels = [] # labels', labels_line)
        boot_py = boot_py.replace('anchors = [] # anchors', anchors_line)
        boot_py = boot_py.replace('sensor.set_windowing((224, 224))', windowing_line)
        with open(self.result_boot_py_path, "w", encoding="utf-8") as f:
            f.write(boot_py)

        return detector, config.detector_result_file_name_prefix