Example #1
def export_onnx_model(model, save_file, opset_version=10):
    if model.__class__.__name__ == "FastSCNN" or (
            model.model_type == "detector"
            and model.__class__.__name__ != "YOLOv3"):
        logging.error(
            "Only image classifier models, detection models(YOLOv3) and semantic segmentation models(except FastSCNN) are supported to export to ONNX"
        )
    try:
        import paddle2onnx
    except ImportError:
        logging.error(
            "You need to install paddle2onnx first: pip install paddle2onnx==0.4"
        )

    import paddle2onnx as p2o

    if p2o.__version__ != '0.4':
        logging.error(
            "You need install paddle2onnx==0.4, but the version of paddle2onnx is {}"
            .format(p2o.__version__))

    if opset_version == 10 and model.__class__.__name__ == "YOLOv3":
        logging.warning(
            "Export for openVINO by default, the output of multiclass_nms exported to onnx will contains background. If you need onnx completely consistent with paddle, please use paddle2onnx to export"
        )

    p2o.register_op_mapper('multiclass_nms', MultiClassNMS4OpenVINO)

    p2o.program2onnx(model.test_prog,
                     scope=model.scope,
                     save_file=save_file,
                     opset_version=opset_version)
Example #2
def coco_bbox_eval(results,
                   coco_gt,
                   with_background=True,
                   is_bbox_normalized=False):
    assert 'bbox' in results[0]
    # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
    # or matplotlib.backends is imported for the first time
    # because pycocotools imports matplotlib
    import matplotlib
    matplotlib.use('Agg')
    from pycocotools.coco import COCO

    cat_ids = coco_gt.getCatIds()

    # when with_background = True, mapping category to classid, like:
    #   background:0, first_class:1, second_class:2, ...
    clsid2catid = dict(
        {i + int(with_background): catid
         for i, catid in enumerate(cat_ids)})

    xywh_results = bbox2out(
        results, clsid2catid, is_bbox_normalized=is_bbox_normalized)

    results = copy.deepcopy(xywh_results)
    if len(xywh_results) == 0:
        logging.warning(
            "The number of valid bbox detected is zero.\n Please use reasonable model and check input data.\n stop eval!"
        )
        return [0.0], results

    map_stats = cocoapi_eval(xywh_results, 'bbox', coco_gt=coco_gt)
    # flush coco evaluation result
    sys.stdout.flush()
    return map_stats, results
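
A minimal sketch of the clsid2catid mapping built above, using hypothetical category ids, showing how with_background shifts class ids so that 0 stays reserved for the background:

cat_ids = [7, 11, 21]  # hypothetical COCO category ids
with_background = True
clsid2catid = {i + int(with_background): catid for i, catid in enumerate(cat_ids)}
print(clsid2catid)  # {1: 7, 2: 11, 3: 21}; class id 0 is the background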
Example #3
 def set_num_samples(self, num_samples):
     if num_samples > len(self.file_list):
         logging.warning(
             "You want set num_samples to {}, but your dataset only has {} samples, so we will keep your dataset num_samples as {}"
             .format(num_samples, len(self.file_list), len(self.file_list)))
         num_samples = len(self.file_list)
     self.num_samples = num_samples
Example #4
def mask_eval(results, coco_gt, resolution, thresh_binarize=0.5):
    assert 'mask' in results[0]
    from pycocotools.coco import COCO

    clsid2catid = {i + 1: v for i, v in enumerate(coco_gt.getCatIds())}

    segm_results = mask2out(results, clsid2catid, resolution, thresh_binarize)
    results = copy.deepcopy(segm_results)
    if len(segm_results) == 0:
        logging.warning(
            "The number of valid mask detected is zero.\n Please use reasonable model and check input data."
        )
        return None, results

    map_stats = cocoapi_eval(segm_results, 'segm', coco_gt=coco_gt)
    return map_stats, results
Example #5
    def create_predictor(self,
                         use_gpu=True,
                         gpu_id=0,
                         use_mkl=False,
                         mkl_thread_num=4,
                         use_trt=False,
                         use_glog=False,
                         memory_optimize=True,
                         max_trt_batch_size=1):
        config = fluid.core.AnalysisConfig(
            os.path.join(self.model_dir, '__model__'),
            os.path.join(self.model_dir, '__params__'))

        if use_gpu:
            # Set the initial GPU memory (in MB) and the device ID
            config.enable_use_gpu(100, gpu_id)
            if use_trt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 10,
                    max_batch_size=max_trt_batch_size,
                    min_subgraph_size=3,
                    precision_mode=fluid.core.AnalysisConfig.Precision.Float32,
                    use_static=False,
                    use_calib_mode=False)
        else:
            config.disable_gpu()
        if use_mkl and not use_gpu:
            if self.model_name not in ["HRNet", "DeepLabv3p", "PPYOLO"]:
                config.enable_mkldnn()
                config.set_cpu_math_library_num_threads(mkl_thread_num)
            else:
                logging.warning(
                    "HRNet/DeepLabv3p/PPYOLO are not supported for the use of mkldnn\n"
                )
        if use_glog:
            config.enable_glog_info()
        else:
            config.disable_glog_info()
        if memory_optimize:
            config.enable_memory_optim()

        # Enable IR graph optimizations, including operator fusion
        config.switch_ir_optim(True)
        # Disable feed/fetch ops; this is required when using the ZeroCopy interface
        config.switch_use_feed_fetch_ops(False)
        predictor = fluid.core.create_paddle_predictor(config)
        return predictor
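
Since switch_use_feed_fetch_ops(False) mandates the ZeroCopy interface, here is a hedged sketch of driving the returned predictor with it; the input shape is hypothetical, and the tensor methods used (copy_from_cpu, zero_copy_run, copy_to_cpu) are the ones the old fluid AnalysisPredictor exposes, as far as I know.

import numpy as np

# `predictor` is the object returned by create_predictor() above.
input_name = predictor.get_input_names()[0]
input_tensor = predictor.get_input_tensor(input_name)
input_tensor.copy_from_cpu(np.zeros((1, 3, 224, 224), dtype='float32'))  # hypothetical shape
predictor.zero_copy_run()
output_name = predictor.get_output_names()[0]
result = predictor.get_output_tensor(output_name).copy_to_cpu()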
Example #6
def mask_eval(results, coco_gt, resolution, thresh_binarize=0.5):
    assert 'mask' in results[0]
    # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
    # or matplotlib.backends is imported for the first time
    # because pycocotools imports matplotlib
    import matplotlib
    matplotlib.use('Agg')
    from pycocotools.coco import COCO

    clsid2catid = {i + 1: v for i, v in enumerate(coco_gt.getCatIds())}

    segm_results = mask2out(results, clsid2catid, resolution, thresh_binarize)
    results = copy.deepcopy(segm_results)
    if len(segm_results) == 0:
        logging.warning(
            "The number of valid mask detected is zero.\n Please use reasonable model and check input data."
        )
        return None, results

    map_stats = cocoapi_eval(segm_results, 'segm', coco_gt=coco_gt)
    return map_stats, results
Example #7
    def evaluate(self,
                 eval_dataset,
                 batch_size=1,
                 epoch_id=None,
                 metric=None,
                 return_details=False):
        """评估。

        Args:
            eval_dataset (paddlex.datasets): 验证数据读取器。
            batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。
            epoch_id (int): 当前评估模型所在的训练轮数。
            metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
                根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
                如为COCODetection,则metric为'COCO'。
            return_details (bool): 是否返回详细信息。默认值为False。

        Returns:
            tuple (metrics, eval_details) /dict (metrics): 当return_details为True时,返回(metrics, eval_details),
                当return_details为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,
                分别表示平均准确率平均值在各个阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。
                eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、
                预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
        """
        self.arrange_transforms(transforms=eval_dataset.transforms,
                                mode='eval')
        if metric is None:
            if hasattr(self, 'metric') and self.metric is not None:
                metric = self.metric
            else:
                if isinstance(eval_dataset, paddlex.datasets.CocoDetection):
                    metric = 'COCO'
                elif isinstance(eval_dataset, paddlex.datasets.VOCDetection):
                    metric = 'VOC'
                else:
                    raise Exception(
                        "eval_dataset should be datasets.VOCDetection or datasets.COCODetection."
                    )
        assert metric in ['COCO', 'VOC'], "Metric only supports 'VOC' or 'COCO'"
        if batch_size > 1:
            batch_size = 1
            logging.warning(
                "Faster RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1."
            )
        dataset = eval_dataset.generator(batch_size=batch_size,
                                         drop_last=False)

        total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
        results = list()
        logging.info(
            "Start to evaluating(total_samples={}, total_steps={})...".format(
                eval_dataset.num_samples, total_steps))
        for step, data in tqdm.tqdm(enumerate(dataset()), total=total_steps):
            images = np.array([d[0] for d in data]).astype('float32')
            im_infos = np.array([d[1] for d in data]).astype('float32')
            im_shapes = np.array([d[3] for d in data]).astype('float32')
            feed_data = {
                'image': images,
                'im_info': im_infos,
                'im_shape': im_shapes,
            }
            outputs = self.exe.run(self.test_prog,
                                   feed=[feed_data],
                                   fetch_list=list(self.test_outputs.values()),
                                   return_numpy=False)
            res = {
                'bbox':
                (np.array(outputs[0]), outputs[0].recursive_sequence_lengths())
            }
            res_im_id = [d[2] for d in data]
            res['im_info'] = (im_infos, [])
            res['im_shape'] = (im_shapes, [])
            res['im_id'] = (np.array(res_im_id), [])
            if metric == 'VOC':
                res_gt_box = []
                res_gt_label = []
                res_is_difficult = []
                for d in data:
                    res_gt_box.extend(d[4])
                    res_gt_label.extend(d[5])
                    res_is_difficult.extend(d[6])
                res_gt_box_lod = [d[4].shape[0] for d in data]
                res_gt_label_lod = [d[5].shape[0] for d in data]
                res_is_difficult_lod = [d[6].shape[0] for d in data]
                res['gt_box'] = (np.array(res_gt_box), [res_gt_box_lod])
                res['gt_label'] = (np.array(res_gt_label), [res_gt_label_lod])
                res['is_difficult'] = (np.array(res_is_difficult),
                                       [res_is_difficult_lod])
            results.append(res)
            logging.debug("[EVAL] Epoch={}, Step={}/{}".format(
                epoch_id, step + 1, total_steps))
        box_ap_stats, eval_details = eval_results(results,
                                                  metric,
                                                  eval_dataset.coco_gt,
                                                  with_background=True)
        metrics = OrderedDict(
            zip(['bbox_mmap' if metric == 'COCO' else 'bbox_map'],
                box_ap_stats))
        if return_details:
            return metrics, eval_details
        return metrics
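
A hedged usage sketch for this evaluate method, assuming a trained paddlex detector and a VOC-format validation set; the paths and the minimal transform pipeline are hypothetical.

import paddlex
from paddlex.det import transforms

eval_transforms = transforms.Compose([transforms.Normalize()])
eval_dataset = paddlex.datasets.VOCDetection(
    data_dir='dataset',                       # hypothetical dataset layout
    file_list='dataset/val_list.txt',
    label_list='dataset/labels.txt',
    transforms=eval_transforms)
model = paddlex.load_model('output/faster_rcnn/best_model')  # hypothetical path
metrics = model.evaluate(eval_dataset, batch_size=1)         # e.g. {'bbox_map': ...}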
Example #8
 def net_initialize(self,
                    startup_prog=None,
                    pretrain_weights=None,
                    fuse_bn=False,
                    save_dir='.',
                    sensitivities_file=None,
                    eval_metric_loss=0.05,
                    resume_checkpoint=None):
     if not resume_checkpoint:
         pretrain_dir = osp.join(save_dir, 'pretrain')
         if not os.path.isdir(pretrain_dir):
             if os.path.exists(pretrain_dir):
                 os.remove(pretrain_dir)
             os.makedirs(pretrain_dir)
         if pretrain_weights is not None and not os.path.exists(
                 pretrain_weights):
             if self.model_type == 'classifier':
                 if pretrain_weights not in ['IMAGENET']:
                     logging.warning(
                         "Pretrain_weights for classifier should be defined as directory path or parameter file or 'IMAGENET' or None, but it is {}, so we force to set it as 'IMAGENET'"
                         .format(pretrain_weights))
                     pretrain_weights = 'IMAGENET'
             elif self.model_type == 'detector':
                 if pretrain_weights not in ['IMAGENET', 'COCO']:
                     logging.warning(
                         "Pretrain_weights for detector should be defined as directory path or parameter file or 'IMAGENET' or 'COCO' or None, but it is {}, so we force to set it as 'IMAGENET'"
                         .format(pretrain_weights))
                     pretrain_weights = 'IMAGENET'
             elif self.model_type == 'segmenter':
                 if pretrain_weights not in [
                         'IMAGENET', 'COCO', 'CITYSCAPES'
                 ]:
                     logging.warning(
                         "Pretrain_weights for segmenter should be defined as directory path or parameter file or 'IMAGENET' or 'COCO' or 'CITYSCAPES', but it is {}, so we force to set it as 'IMAGENET'"
                         .format(pretrain_weights))
                     pretrain_weights = 'IMAGENET'
         if hasattr(self, 'backbone'):
             backbone = self.backbone
         else:
             backbone = self.__class__.__name__
             if backbone == "HRNet":
                 backbone = backbone + "_W{}".format(self.width)
         class_name = self.__class__.__name__
         pretrain_weights = get_pretrain_weights(pretrain_weights,
                                                 class_name, backbone,
                                                 pretrain_dir)
     if startup_prog is None:
         startup_prog = fluid.default_startup_program()
     self.exe.run(startup_prog)
     if resume_checkpoint:
         logging.info(
             "Resume checkpoint from {}.".format(resume_checkpoint),
             use_color=True)
         paddlex.utils.utils.load_pretrain_weights(self.exe,
                                                   self.train_prog,
                                                   resume_checkpoint,
                                                   resume=True)
         if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
             raise Exception(
                 "There's not model.yml in {}".format(resume_checkpoint))
         with open(osp.join(resume_checkpoint, "model.yml")) as f:
             info = yaml.load(f.read(), Loader=yaml.Loader)
             self.completed_epochs = info['completed_epochs']
     elif pretrain_weights is not None:
         logging.info(
             "Load pretrain weights from {}.".format(pretrain_weights),
             use_color=True)
         paddlex.utils.utils.load_pretrain_weights(self.exe,
                                                   self.train_prog,
                                                   pretrain_weights,
                                                   fuse_bn)
      # Perform model pruning
     if sensitivities_file is not None:
         import paddleslim
         from .slim.prune_config import get_sensitivities
         sensitivities_file = get_sensitivities(sensitivities_file, self,
                                                save_dir)
         from .slim.prune import get_params_ratios, prune_program
         logging.info(
             "Start to prune program with eval_metric_loss = {}".format(
                 eval_metric_loss),
             use_color=True)
         origin_flops = paddleslim.analysis.flops(self.test_prog)
         prune_params_ratios = get_params_ratios(
             sensitivities_file, eval_metric_loss=eval_metric_loss)
         prune_program(self, prune_params_ratios)
         current_flops = paddleslim.analysis.flops(self.test_prog)
         remaining_ratio = current_flops / origin_flops
         logging.info(
             "Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
             .format(origin_flops, current_flops, remaining_ratio),
             use_color=True)
         self.status = 'Prune'
Example #9
    def evaluate(self,
                 eval_dataset,
                 batch_size=1,
                 epoch_id=None,
                 metric=None,
                 return_details=False):
        """评估。

        Args:
            eval_dataset (paddlex.datasets): 验证数据读取器。
            batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。
            epoch_id (int): 当前评估模型所在的训练轮数。
            metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
                根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
                如为COCODetection,则metric为'COCO'。
            return_details (bool): 是否返回详细信息。默认值为False。

        Returns:
            tuple (metrics, eval_details) /dict (metrics): 当return_details为True时,返回(metrics, eval_details),
                当return_details为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'和'segm_mmap'
                或者’bbox_map‘和'segm_map',分别表示预测框和分割区域平均准确率平均值在
                各个IoU阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。eval_details为dict,
                包含bbox、mask和gt三个关键字。其中关键字bbox的键值是一个列表,列表中每个元素代表一个预测结果,
                一个预测结果是一个由图像id,预测框类别id, 预测框坐标,预测框得分组成的列表。
                关键字mask的键值是一个列表,列表中每个元素代表各预测框内物体的分割结果,分割结果由图像id、
                预测框类别id、表示预测框内各像素点是否属于物体的二值图、预测框得分。
                而关键字gt的键值是真实标注框的相关信息。
        """
        input_channel = getattr(self, 'input_channel', 3)
        arrange_transforms(model_type=self.model_type,
                           class_name=self.__class__.__name__,
                           transforms=eval_dataset.transforms,
                           mode='eval',
                           input_channel=input_channel)
        if metric is None:
            if hasattr(self, 'metric') and self.metric is not None:
                metric = self.metric
            else:
                if isinstance(eval_dataset, paddlex.datasets.CocoDetection):
                    metric = 'COCO'
                else:
                    raise Exception(
                        "eval_dataset should be datasets.COCODetection.")
        assert metric in ['COCO', 'VOC'], "Metric only supports 'VOC' or 'COCO'"
        if batch_size > 1:
            batch_size = 1
            logging.warning(
                "Mask RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1."
            )
        data_generator = eval_dataset.generator(batch_size=batch_size,
                                                drop_last=False)

        total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
        results = list()
        logging.info(
            "Start to evaluating(total_samples={}, total_steps={})...".format(
                eval_dataset.num_samples, total_steps))
        for step, data in tqdm.tqdm(enumerate(data_generator()),
                                    total=total_steps):
            images = np.array([d[0] for d in data]).astype('float32')
            im_infos = np.array([d[1] for d in data]).astype('float32')
            im_shapes = np.array([d[3] for d in data]).astype('float32')
            feed_data = {
                'image': images,
                'im_info': im_infos,
                'im_shape': im_shapes,
            }
            with fluid.scope_guard(self.scope):
                outputs = self.exe.run(self.test_prog,
                                       feed=[feed_data],
                                       fetch_list=list(
                                           self.test_outputs.values()),
                                       return_numpy=False)
            res = {
                'bbox': (np.array(outputs[0]),
                         outputs[0].recursive_sequence_lengths()),
                'mask':
                (np.array(outputs[1]), outputs[1].recursive_sequence_lengths())
            }
            res_im_id = [d[2] for d in data]
            res['im_info'] = (im_infos, [])
            res['im_shape'] = (im_shapes, [])
            res['im_id'] = (np.array(res_im_id), [])
            results.append(res)
            logging.debug("[EVAL] Epoch={}, Step={}/{}".format(
                epoch_id, step + 1, total_steps))

        ap_stats, eval_details = eval_results(
            results,
            'COCO',
            eval_dataset.coco_gt,
            with_background=True,
            resolution=self.mask_head_resolution)
        if metric == 'VOC':
            if isinstance(ap_stats[0], np.ndarray) and isinstance(
                    ap_stats[1], np.ndarray):
                metrics = OrderedDict(
                    zip(['bbox_map', 'segm_map'],
                        [ap_stats[0][1], ap_stats[1][1]]))
            else:
                metrics = OrderedDict(zip(['bbox_map', 'segm_map'],
                                          [0.0, 0.0]))
        elif metric == 'COCO':
            if isinstance(ap_stats[0], np.ndarray) and isinstance(
                    ap_stats[1], np.ndarray):
                metrics = OrderedDict(
                    zip(['bbox_mmap', 'segm_mmap'],
                        [ap_stats[0][0], ap_stats[1][0]]))
            else:
                metrics = OrderedDict(
                    zip(['bbox_mmap', 'segm_mmap'], [0.0, 0.0]))
        if return_details:
            return metrics, eval_details
        return metrics
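
A hedged sketch of consuming the metrics this Mask RCNN variant returns; model and eval_dataset are assumed to be set up as in the previous example, but with a COCO-format dataset.

metrics, details = model.evaluate(eval_dataset, batch_size=1, return_details=True)
print(metrics['bbox_mmap'], metrics['segm_mmap'])   # box and mask mmAP
print(len(details['bbox']), len(details['mask']))   # one entry per prediction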
Example #10
def get_pretrain_weights(flag, class_name, backbone, save_dir):
    if flag is None:
        return None
    elif osp.isdir(flag):
        return flag
    elif osp.isfile(flag):
        return flag
    warning_info = "{} cannot be fine-tuned with weights pretrained on the {} dataset, so pretrain_weights is forced to {}"
    if flag == 'COCO':
        if class_name == "FasterRCNN" and backbone in ['ResNet18'] or \
            class_name == "MaskRCNN" and backbone in ['ResNet18', 'HRNet_W18'] or \
            class_name == 'DeepLabv3p' and backbone in ['Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']:
            model_name = '{}_{}'.format(class_name, backbone)
            logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'HRNet':
            logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'FastSCNN':
            logging.warning(
                warning_info.format(class_name, flag, 'CITYSCAPES'))
            flag = 'CITYSCAPES'
    elif flag == 'CITYSCAPES':
        model_name = '{}_{}'.format(class_name, backbone)
        if class_name == 'UNet':
            logging.warning(warning_info.format(class_name, flag, 'COCO'))
            flag = 'COCO'
        if class_name == 'HRNet' and backbone.split('_')[
                -1] in ['W30', 'W32', 'W40', 'W48', 'W60', 'W64']:
            logging.warning(warning_info.format(backbone, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        if class_name == 'DeepLabv3p' and backbone in [
                'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
                'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
        ]:
            model_name = '{}_{}'.format(class_name, backbone)
            logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
    elif flag == 'IMAGENET':
        if class_name == 'UNet':
            logging.warning(warning_info.format(class_name, flag, 'COCO'))
            flag = 'COCO'
        elif class_name == 'FastSCNN':
            logging.warning(
                warning_info.format(class_name, flag, 'CITYSCAPES'))
            flag = 'CITYSCAPES'

    if flag == 'IMAGENET':
        new_save_dir = save_dir
        if hasattr(paddlex, 'pretrain_dir'):
            new_save_dir = paddlex.pretrain_dir
        if backbone.startswith('Xception'):
            backbone = 'Seg{}'.format(backbone)
        elif backbone == 'MobileNetV2':
            backbone = 'MobileNetV2_x1.0'
        elif backbone == 'MobileNetV3_small_ssld':
            backbone = 'MobileNetV3_small_x1_0_ssld'
        elif backbone == 'MobileNetV3_large_ssld':
            backbone = 'MobileNetV3_large_x1_0_ssld'
        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
            if backbone == 'ResNet50':
                backbone = 'DetResNet50'
        assert backbone in image_pretrain, "There are no ImageNet pretrained weights for {}, you may try COCO.".format(
            backbone)

        #        if backbone == 'AlexNet':
        #            url = image_pretrain[backbone]
        #            fname = osp.split(url)[-1].split('.')[0]
        #            paddlex.utils.download_and_decompress(url, path=new_save_dir)
        #            return osp.join(new_save_dir, fname)
        try:
            hub.download(backbone, save_path=new_save_dir)
        except Exception as e:
            if isinstance(e, hub.ResourceNotFoundError):
                raise Exception("Resource for backbone {} not found".format(
                    backbone))
            elif isinstance(e, hub.ServerConnectionError):
                raise Exception(
                    "Cannot get reource for backbone {}, please check your internet connecgtion"
                    .format(backbone))
            else:
                raise Exception(
                    "Unexpected error, please make sure paddlehub >= 1.6.2")
        return osp.join(new_save_dir, backbone)
    elif flag in ['COCO', 'CITYSCAPES']:
        new_save_dir = save_dir
        if hasattr(paddlex, 'pretrain_dir'):
            new_save_dir = paddlex.pretrain_dir
        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p']:
            backbone = '{}_{}'.format(class_name, backbone)
        backbone = "{}_{}".format(backbone, flag)
        if flag == 'COCO':
            url = coco_pretrain[backbone]
        elif flag == 'CITYSCAPES':
            url = cityscapes_pretrain[backbone]
        fname = osp.split(url)[-1].split('.')[0]
        #        paddlex.utils.download_and_decompress(url, path=new_save_dir)
        #        return osp.join(new_save_dir, fname)
        try:
            hub.download(backbone, save_path=new_save_dir)
        except Exception as e:
            if isinstance(e, hub.ResourceNotFoundError):
                raise Exception("Resource for backbone {} not found".format(
                    backbone))
            elif isinstance(e, hub.ServerConnectionError):
                raise Exception(
                    "Cannot get resource for backbone {}, please check your internet connection"
                    .format(backbone))
            else:
                raise Exception(
                    "Unexpected error, please make sure paddlehub >= 1.6.2")
        return osp.join(new_save_dir, backbone)
    else:
        raise Exception(
            "pretrain_weights need to be defined as directory path or 'IMAGENET' or 'COCO' or 'Cityscapes' (download pretrain weights automatically)."
        )
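
A minimal sketch of the short-circuit paths at the top of get_pretrain_weights: None and existing local paths are returned untouched, so only the dataset flags can trigger remapping or downloads.

print(get_pretrain_weights(None, 'YOLOv3', 'MobileNetV1', './pretrain'))  # -> None
# An existing directory or parameter file is passed through as-is (hypothetical path):
# get_pretrain_weights('./my_weights', 'YOLOv3', 'MobileNetV1', './pretrain') -> './my_weights'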
Example #11
    def opset_10(cls, graph, node, **kw):
        from paddle2onnx.constant import dtypes
        import numpy as np
        result_name = node.output('Out', 0)
        background = node.attr('background_label')
        normalized = node.attr('normalized')
        if not normalized:
            logging.warning(
                "The parameter normalized of multiclass_nms OP of Paddle is False, which differs from ONNX."
                " Please set normalized=True in multiclass_nms of Paddle, see doc Q1 in"
                " https://github.com/PaddlePaddle/paddle2onnx/blob/develop/FAQ.md")

        # convert the paddle attributes to onnx constant tensors
        node_score_threshold = graph.make_node(
            'Constant',
            inputs=[],
            dtype=dtypes.ONNX.FLOAT,
            value=[float(node.attr('score_threshold'))])

        node_iou_threshold = graph.make_node(
            'Constant',
            inputs=[],
            dtype=dtypes.ONNX.FLOAT,
            value=[float(node.attr('nms_threshold'))])

        node_keep_top_k = graph.make_node(
            'Constant',
            inputs=[],
            dtype=dtypes.ONNX.INT64,
            value=[np.int64(node.attr('keep_top_k'))])

        node_keep_top_k_2D = graph.make_node('Constant',
                                             inputs=[],
                                             dtype=dtypes.ONNX.INT64,
                                             dims=[1, 1],
                                             value=[node.attr('keep_top_k')])

        # the paddle data format is x1,y1,x2,y2
        kwargs = {'center_point_box': 0}

        node_select_nms = graph.make_node(
            'NonMaxSuppression',
            inputs=[node.input('BBoxes', 0), node.input('Scores', 0), node_keep_top_k,\
                node_iou_threshold, node_score_threshold])

        # step 1 nodes select the nms class
        # create some const value to use
        node_const_value = [result_name+"@const_0",
            result_name+"@const_1",\
            result_name+"@const_2",\
            result_name+"@const_-1"]
        value_const_value = [0, 1, 2, -1]
        for name, value in zip(node_const_value, value_const_value):
            graph.make_node('Constant',
                            layer_name=name,
                            inputs=[],
                            outputs=[name],
                            dtype=dtypes.ONNX.INT64,
                            value=[value])

        # In this code block, we decode the raw score data and reshape N * C * M to 1 * N*C*M;
        # at the same time, we decode the select indices to 1 * D and gather the select_indices
        node_gather_1 = graph.make_node(
            'Gather',
            inputs=[node_select_nms, result_name + "@const_1"],
            axis=1)

        node_gather_1 = graph.make_node('Unsqueeze',
                                        inputs=[node_gather_1],
                                        axes=[0])

        node_gather_2 = graph.make_node(
            'Gather',
            inputs=[node_select_nms, result_name + "@const_2"],
            axis=1)

        node_gather_2 = graph.make_node('Unsqueeze',
                                        inputs=[node_gather_2],
                                        axes=[0])

        # reshape scores N * C * M to (N*C*M) * 1
        node_reshape_scores_rank1 = graph.make_node(
            "Reshape",
            inputs=[node.input('Scores', 0), result_name + "@const_-1"])

        # get the shape of scores
        node_shape_scores = graph.make_node('Shape',
                                            inputs=node.input('Scores'))

        # gather the index: 2 shape of scores
        node_gather_scores_dim1 = graph.make_node(
            'Gather',
            inputs=[node_shape_scores, result_name + "@const_2"],
            axis=0)

        # mul class * M
        node_mul_classnum_boxnum = graph.make_node(
            'Mul', inputs=[node_gather_1, node_gather_scores_dim1])

        # add class * M * index
        node_add_class_M_index = graph.make_node(
            'Add', inputs=[node_mul_classnum_boxnum, node_gather_2])

        # Squeeze the indices to 1 dim
        node_squeeze_select_index = graph.make_node(
            'Squeeze', inputs=[node_add_class_M_index], axes=[0, 2])

        # gather the data from flatten scores
        node_gather_select_scores = graph.make_node(
            'Gather',
            inputs=[node_reshape_scores_rank1, node_squeeze_select_index],
            axis=0)

        # get nums to input TopK
        node_shape_select_num = graph.make_node(
            'Shape', inputs=[node_gather_select_scores])

        node_gather_select_num = graph.make_node(
            'Gather',
            inputs=[node_shape_select_num, result_name + "@const_0"],
            axis=0)

        node_unsqueeze_select_num = graph.make_node(
            'Unsqueeze', inputs=[node_gather_select_num], axes=[0])

        node_concat_topK_select_num = graph.make_node(
            'Concat',
            inputs=[node_unsqueeze_select_num, node_keep_top_k_2D],
            axis=0)

        node_cast_concat_topK_select_num = graph.make_node(
            'Cast', inputs=[node_concat_topK_select_num], to=6)
        # get min(topK, num_select)
        node_compare_topk_num_select = graph.make_node(
            'ReduceMin', inputs=[node_cast_concat_topK_select_num], keepdims=0)

        # unsqueeze the indices to 1D tensor
        node_unsqueeze_topk_select_indices = graph.make_node(
            'Unsqueeze', inputs=[node_compare_topk_num_select], axes=[0])

        # cast the indices to INT64
        node_cast_topk_indices = graph.make_node(
            'Cast', inputs=[node_unsqueeze_topk_select_indices], to=7)

        # select topk scores  indices
        outputs_topk_select_topk_indices = [result_name + "@topk_select_topk_values",\
            result_name + "@topk_select_topk_indices"]
        node_topk_select_topk_indices = graph.make_node(
            'TopK',
            inputs=[node_gather_select_scores, node_cast_topk_indices],
            outputs=outputs_topk_select_topk_indices)

        # gather topk label, scores, boxes
        node_gather_topk_scores = graph.make_node(
            'Gather',
            inputs=[
                node_gather_select_scores, outputs_topk_select_topk_indices[1]
            ],
            axis=0)

        node_gather_topk_class = graph.make_node(
            'Gather',
            inputs=[node_gather_1, outputs_topk_select_topk_indices[1]],
            axis=1)

        # gather the boxes need to gather the boxes id, then get boxes
        node_gather_topk_boxes_id = graph.make_node(
            'Gather',
            inputs=[node_gather_2, outputs_topk_select_topk_indices[1]],
            axis=1)

        # squeeze the gather_topk_boxes_id to 1 dim
        node_squeeze_topk_boxes_id = graph.make_node(
            'Squeeze', inputs=[node_gather_topk_boxes_id], axes=[0, 2])

        node_gather_select_boxes = graph.make_node(
            'Gather',
            inputs=[node.input('BBoxes', 0), node_squeeze_topk_boxes_id],
            axis=1)

        # concat the final result
        # before concat need to cast the class to float
        node_cast_topk_class = graph.make_node('Cast',
                                               inputs=[node_gather_topk_class],
                                               to=1)

        node_unsqueeze_topk_scores = graph.make_node(
            'Unsqueeze', inputs=[node_gather_topk_scores], axes=[0, 2])

        inputs_concat_final_results = [node_cast_topk_class, node_unsqueeze_topk_scores, \
            node_gather_select_boxes]
        node_sort_by_score_results = graph.make_node(
            'Concat', inputs=inputs_concat_final_results, axis=2)

        # select topk classes indices
        node_squeeze_cast_topk_class = graph.make_node(
            'Squeeze', inputs=[node_cast_topk_class], axes=[0, 2])
        node_neg_squeeze_cast_topk_class = graph.make_node(
            'Neg', inputs=[node_squeeze_cast_topk_class])

        outputs_topk_select_classes_indices = [result_name + "@topk_select_topk_classes_scores",\
            result_name + "@topk_select_topk_classes_indices"]
        node_topk_select_topk_indices = graph.make_node(
            'TopK',
            inputs=[node_neg_squeeze_cast_topk_class, node_cast_topk_indices],
            outputs=outputs_topk_select_classes_indices)
        node_concat_final_results = graph.make_node(
            'Gather',
            inputs=[
                node_sort_by_score_results,
                outputs_topk_select_classes_indices[1]
            ],
            axis=1)
        node_concat_final_results = graph.make_node(
            'Squeeze',
            inputs=[node_concat_final_results],
            outputs=[node.output('Out', 0)],
            axes=[0])

        if node.type == 'multiclass_nms2':
            graph.make_node('Squeeze',
                            inputs=[node_gather_2],
                            outputs=node.output('Index'),
                            axes=[0])
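
The index arithmetic in the middle of this mapper (class * M + box_index into the scores reshaped to one dimension) can be checked with a small NumPy sketch using hypothetical sizes:

import numpy as np

N, C, M = 1, 3, 4                                   # hypothetical batch/class/box counts
scores = np.random.rand(N, C, M).astype('float32')
flat = scores.reshape(-1)                           # matches the Reshape to @const_-1 above
cls_idx, box_idx = 2, 1                             # a (class, box) pair selected by NMS
assert flat[cls_idx * M + box_idx] == scores[0, cls_idx, box_idx]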
Example #12
    def __init__(self,
                 data_dir,
                 file_list,
                 label_list,
                 transforms=None,
                 num_workers='auto',
                 buffer_size=100,
                 parallel_method='process',
                 shuffle=False):
        # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
        # or matplotlib.backends is imported for the first time
        # because pycocotools imports matplotlib
        import matplotlib
        matplotlib.use('Agg')
        from pycocotools.coco import COCO
        super(VOCDetection, self).__init__(transforms=transforms,
                                           num_workers=num_workers,
                                           buffer_size=buffer_size,
                                           parallel_method=parallel_method,
                                           shuffle=shuffle)
        self.file_list = list()
        self.labels = list()
        self._epoch = 0

        annotations = {}
        annotations['images'] = []
        annotations['categories'] = []
        annotations['annotations'] = []

        cname2cid = OrderedDict()
        label_id = 1
        with open(label_list, 'r', encoding=get_encoding(label_list)) as fr:
            for line in fr.readlines():
                cname2cid[line.strip()] = label_id
                label_id += 1
                self.labels.append(line.strip())
        logging.info("Starting to read file list from dataset...")
        for k, v in cname2cid.items():
            annotations['categories'].append({
                'supercategory': 'component',
                'id': v,
                'name': k
            })
        ct = 0
        ann_ct = 0
        with open(file_list, 'r', encoding=get_encoding(file_list)) as fr:
            while True:
                line = fr.readline()
                if not line:
                    break
                if len(line.strip().split()) > 2:
                    raise Exception(
                        "A space is defined as the separator, but it exists in image or label name {}."
                        .format(line))
                img_file, xml_file = [osp.join(data_dir, x) \
                        for x in line.strip().split()[:2]]
                img_file = path_normalization(img_file)
                xml_file = path_normalization(xml_file)
                if not is_pic(img_file):
                    continue
                if not osp.isfile(xml_file):
                    continue
                if not osp.exists(img_file):
                    logging.warning(
                        'The image file {} does not exist!'.format(img_file))
                    continue
                if not osp.exists(xml_file):
                    logging.warning(
                        'The annotation file {} does not exist!'.format(
                            xml_file))
                    continue
                tree = ET.parse(xml_file)
                if tree.find('id') is None:
                    im_id = np.array([ct])
                else:
                    ct = int(tree.find('id').text)
                    im_id = np.array([int(tree.find('id').text)])
                pattern = re.compile('<object>', re.IGNORECASE)
                obj_match = pattern.findall(
                    str(ET.tostringlist(tree.getroot())))
                if len(obj_match) == 0:
                    continue
                obj_tag = obj_match[0][1:-1]
                objs = tree.findall(obj_tag)
                pattern = re.compile('<size>', re.IGNORECASE)
                size_tag = pattern.findall(str(ET.tostringlist(
                    tree.getroot())))[0][1:-1]
                size_element = tree.find(size_tag)
                pattern = re.compile('<width>', re.IGNORECASE)
                width_tag = pattern.findall(str(
                    ET.tostringlist(size_element)))[0][1:-1]
                im_w = float(size_element.find(width_tag).text)
                pattern = re.compile('<height>', re.IGNORECASE)
                height_tag = pattern.findall(str(
                    ET.tostringlist(size_element)))[0][1:-1]
                im_h = float(size_element.find(height_tag).text)
                gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
                gt_class = np.zeros((len(objs), 1), dtype=np.int32)
                gt_score = np.ones((len(objs), 1), dtype=np.float32)
                is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
                difficult = np.zeros((len(objs), 1), dtype=np.int32)
                for i, obj in enumerate(objs):
                    pattern = re.compile('<name>', re.IGNORECASE)
                    name_tag = pattern.findall(str(
                        ET.tostringlist(obj)))[0][1:-1]
                    cname = obj.find(name_tag).text.strip()
                    gt_class[i][0] = cname2cid[cname]
                    pattern = re.compile('<difficult>', re.IGNORECASE)
                    diff_tag = pattern.findall(str(ET.tostringlist(obj)))
                    if len(diff_tag) == 0:
                        _difficult = 0
                    else:
                        diff_tag = diff_tag[0][1:-1]
                        try:
                            _difficult = int(obj.find(diff_tag).text)
                        except Exception:
                            _difficult = 0
                    pattern = re.compile('<bndbox>', re.IGNORECASE)
                    box_tag = pattern.findall(str(ET.tostringlist(obj)))
                    if len(box_tag) == 0:
                        logging.warning(
                            "There's no field '<bndbox>' in one of object, so this object will be ignored. xml file: {}"
                            .format(xml_file))
                        continue
                    box_tag = box_tag[0][1:-1]
                    box_element = obj.find(box_tag)
                    pattern = re.compile('<xmin>', re.IGNORECASE)
                    xmin_tag = pattern.findall(
                        str(ET.tostringlist(box_element)))[0][1:-1]
                    x1 = float(box_element.find(xmin_tag).text)
                    pattern = re.compile('<ymin>', re.IGNORECASE)
                    ymin_tag = pattern.findall(
                        str(ET.tostringlist(box_element)))[0][1:-1]
                    y1 = float(box_element.find(ymin_tag).text)
                    pattern = re.compile('<xmax>', re.IGNORECASE)
                    xmax_tag = pattern.findall(
                        str(ET.tostringlist(box_element)))[0][1:-1]
                    x2 = float(box_element.find(xmax_tag).text)
                    pattern = re.compile('<ymax>', re.IGNORECASE)
                    ymax_tag = pattern.findall(
                        str(ET.tostringlist(box_element)))[0][1:-1]
                    y2 = float(box_element.find(ymax_tag).text)
                    x1 = max(0, x1)
                    y1 = max(0, y1)
                    if im_w > 0.5 and im_h > 0.5:
                        x2 = min(im_w - 1, x2)
                        y2 = min(im_h - 1, y2)
                    gt_bbox[i] = [x1, y1, x2, y2]
                    is_crowd[i][0] = 0
                    difficult[i][0] = _difficult
                    annotations['annotations'].append({
                        'iscrowd':
                        0,
                        'image_id':
                        int(im_id[0]),
                        'bbox': [x1, y1, x2 - x1 + 1, y2 - y1 + 1],
                        'area':
                        float((x2 - x1 + 1) * (y2 - y1 + 1)),
                        'category_id':
                        cname2cid[cname],
                        'id':
                        ann_ct,
                        'difficult':
                        _difficult
                    })
                    ann_ct += 1

                im_info = {
                    'im_id': im_id,
                    'image_shape': np.array([im_h, im_w]).astype('int32'),
                }
                label_info = {
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_bbox': gt_bbox,
                    'gt_score': gt_score,
                    'gt_poly': [],
                    'difficult': difficult
                }
                voc_rec = (im_info, label_info)
                if len(objs) != 0:
                    self.file_list.append([img_file, voc_rec])
                    ct += 1
                    annotations['images'].append({
                        'height':
                        im_h,
                        'width':
                        im_w,
                        'id':
                        int(im_id[0]),
                        'file_name':
                        osp.split(img_file)[1]
                    })

        if not len(self.file_list) > 0:
            raise Exception('No valid VOC record found in %s' % (file_list))
        logging.info("{} samples in file {}".format(len(self.file_list),
                                                    file_list))
        self.num_samples = len(self.file_list)
        self.coco_gt = COCO()
        self.coco_gt.dataset = annotations
        self.coco_gt.createIndex()
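
A self-contained sketch of the case-insensitive tag lookup this reader relies on, run on a toy VOC snippet (the XML content is hypothetical):

import re
import xml.etree.ElementTree as ET

xml = "<annotation><OBJECT><name>cat</name></OBJECT></annotation>"
tree = ET.ElementTree(ET.fromstring(xml))
pattern = re.compile('<object>', re.IGNORECASE)
obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
print(tree.findall(obj_tag)[0].find('name').text)  # -> cat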
Example #13
    def __init__(self,
                 data_dir,
                 ann_file,
                 transforms=None,
                 num_workers='auto',
                 buffer_size=100,
                 parallel_method='process',
                 shuffle=False):
        from pycocotools.coco import COCO

        super(VOCDetection, self).__init__(
            transforms=transforms,
            num_workers=num_workers,
            buffer_size=buffer_size,
            parallel_method=parallel_method,
            shuffle=shuffle)
        self.file_list = list()
        self.labels = list()
        self._epoch = 0

        coco = COCO(ann_file)
        self.coco_gt = coco
        img_ids = coco.getImgIds()
        cat_ids = coco.getCatIds()
        catid2clsid = dict({catid: i + 1 for i, catid in enumerate(cat_ids)})
        cname2cid = dict({
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in catid2clsid.items()
        })
        for label, cid in sorted(cname2cid.items(), key=lambda d: d[1]):
            self.labels.append(label)
        logging.info("Starting to read file list from dataset...")
        for img_id in img_ids:
            img_anno = coco.loadImgs(img_id)[0]
            im_fname = osp.join(data_dir, img_anno['file_name'])
            if not is_pic(im_fname):
                continue
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])
            ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
            instances = coco.loadAnns(ins_anno_ids)

            bboxes = []
            for inst in instances:
                x, y, box_w, box_h = inst['bbox']
                x1 = max(0, x)
                y1 = max(0, y)
                x2 = min(im_w - 1, x1 + max(0, box_w - 1))
                y2 = min(im_h - 1, y1 + max(0, box_h - 1))
                if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
                    inst['clean_bbox'] = [x1, y1, x2, y2]
                    bboxes.append(inst)
                else:
                    logging.warning(
                        "Found an invalid bbox in annotations: im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}."
                        .format(img_id, float(inst['area']), x1, y1, x2, y2))
            num_bbox = len(bboxes)
            gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
            gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
            gt_score = np.ones((num_bbox, 1), dtype=np.float32)
            is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
            difficult = np.zeros((num_bbox, 1), dtype=np.int32)
            gt_poly = None

            for i, box in enumerate(bboxes):
                catid = box['category_id']
                gt_class[i][0] = catid2clsid[catid]
                gt_bbox[i, :] = box['clean_bbox']
                is_crowd[i][0] = box['iscrowd']
                if 'segmentation' in box:
                    if gt_poly is None:
                        gt_poly = [None] * num_bbox
                    gt_poly[i] = box['segmentation']

            im_info = {
                'im_id': np.array([img_id]).astype('int32'),
                'image_shape': np.array([im_h, im_w]).astype('int32'),
            }
            label_info = {
                'is_crowd': is_crowd,
                'gt_class': gt_class,
                'gt_bbox': gt_bbox,
                'gt_score': gt_score,
                'difficult': difficult
            }
            if gt_poly is not None:
                label_info['gt_poly'] = gt_poly

            coco_rec = (im_info, label_info)
            self.file_list.append([im_fname, coco_rec])
        if not len(self.file_list) > 0:
            raise Exception('No valid COCO record found in %s' % (ann_file))
        logging.info("{} samples in file {}".format(
            len(self.file_list), ann_file))
        self.num_samples = len(self.file_list)
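
A small worked sketch of the COCO [x, y, w, h] to clipped [x1, y1, x2, y2] conversion performed in the loop above, with hypothetical numbers:

im_w, im_h = 100.0, 80.0
x, y, box_w, box_h = 90.0, 70.0, 20.0, 20.0
x1, y1 = max(0, x), max(0, y)
x2 = min(im_w - 1, x1 + max(0, box_w - 1))
y2 = min(im_h - 1, y1 + max(0, box_h - 1))
print(x1, y1, x2, y2)  # 90.0 70.0 99.0 79.0 -- the box is clipped to the image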
Example #14
    def net_initialize(self,
                       startup_prog=None,
                       pretrain_weights=None,
                       fuse_bn=False,
                       save_dir='.',
                       sensitivities_file=None,
                       eval_metric_loss=0.05,
                       resume_checkpoint=None):
        if not resume_checkpoint:
            pretrain_dir = osp.join(save_dir, 'pretrain')
            if not os.path.isdir(pretrain_dir):
                if os.path.exists(pretrain_dir):
                    os.remove(pretrain_dir)
                os.makedirs(pretrain_dir)
            if pretrain_weights is not None and not os.path.exists(
                    pretrain_weights):
                if self.model_type == 'classifier':
                    if pretrain_weights not in ['IMAGENET', 'BAIDU10W']:
                        logging.warning(
                            "Path of pretrain_weights('{}') is not exists!".
                            format(pretrain_weights))
                        logging.warning(
                            "Pretrain_weights will be forced to set as 'IMAGENET', if you don't want to use pretrain weights, set pretrain_weights=None."
                        )
                        pretrain_weights = 'IMAGENET'
                elif self.model_type == 'detector':
                    if pretrain_weights not in ['IMAGENET', 'COCO']:
                        logging.warning(
                            "Path of pretrain_weights('{}') is not exists!".
                            format(pretrain_weights))
                        logging.warning(
                            "Pretrain_weights will be forced to set as 'IMAGENET', if you don't want to use pretrain weights, set pretrain_weights=None."
                        )
                        pretrain_weights = 'IMAGENET'
                elif self.model_type == 'segmenter':
                    if pretrain_weights not in [
                            'IMAGENET', 'COCO', 'CITYSCAPES'
                    ]:
                        logging.warning(
                            "Path of pretrain_weights('{}') is not exists!".
                            format(pretrain_weights))
                        logging.warning(
                            "Pretrain_weights will be forced to set as 'IMAGENET', if you don't want to use pretrain weights, set pretrain_weights=None."
                        )
                        pretrain_weights = 'IMAGENET'
            if hasattr(self, 'backbone'):
                backbone = self.backbone
            else:
                backbone = self.__class__.__name__
                if backbone == "HRNet":
                    backbone = backbone + "_W{}".format(self.width)
            class_name = self.__class__.__name__
            pretrain_weights = get_pretrain_weights(
                pretrain_weights, class_name, backbone, pretrain_dir)
        if startup_prog is None:
            startup_prog = fluid.default_startup_program()
        self.exe.run(startup_prog)

        if not resume_checkpoint and pretrain_weights:
            logging.info(
                "Load pretrain weights from {}.".format(pretrain_weights),
                use_color=True)
            paddlex.utils.utils.load_pretrain_weights(
                self.exe, self.train_prog, pretrain_weights, fuse_bn)

        # Perform model pruning
        if sensitivities_file is not None:
            import paddle
            version = paddle.__version__.strip().split('.')
            if version[0] == '2' or (version[0] == '0' and
                                     hasattr(paddle, 'enable_static')):
                raise Exception(
                    'Model pruning is not ready when using paddle>=2.0.0, please downgrade paddle to 1.8.5.'
                )
            import paddleslim
            from .slim.prune_config import get_sensitivities
            sensitivities_file = get_sensitivities(sensitivities_file, self,
                                                   save_dir)
            from .slim.prune import get_params_ratios, prune_program
            logging.info(
                "Start to prune program with eval_metric_loss = {}".format(
                    eval_metric_loss),
                use_color=True)
            origin_flops = paddleslim.analysis.flops(self.test_prog)
            prune_params_ratios = get_params_ratios(
                sensitivities_file, eval_metric_loss=eval_metric_loss)
            prune_program(self, prune_params_ratios)
            current_flops = paddleslim.analysis.flops(self.test_prog)
            remaining_ratio = current_flops / origin_flops
            logging.info(
                "Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
                .format(origin_flops, current_flops, remaining_ratio),
                use_color=True)
            self.status = 'Prune'

        if resume_checkpoint:
            logging.info(
                "Resume checkpoint from {}.".format(resume_checkpoint),
                use_color=True)
            paddlex.utils.utils.load_pretrain_weights(
                self.exe, self.train_prog, resume_checkpoint, resume=True)
            if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
                raise Exception("There's not model.yml in {}".format(
                    resume_checkpoint))
            with open(osp.join(resume_checkpoint, "model.yml")) as f:
                info = yaml.load(f.read(), Loader=yaml.Loader)
                self.completed_epochs = info['completed_epochs']
Example #15
def get_pretrain_weights(flag, class_name, backbone, save_dir):
    if flag is None:
        return None
    elif osp.isdir(flag):
        return flag
    elif osp.isfile(flag):
        return flag
    warning_info = "{} cannot be fine-tuned with weights pretrained on the {} dataset, so pretrain_weights is forced to {}"
    if flag == 'COCO':
        if class_name == 'DeepLabv3p' and backbone in [
                'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
                'MobileNetV2_x1.5', 'MobileNetV2_x2.0',
                'MobileNetV3_large_x1_0_ssld'
        ]:
            model_name = '{}_{}'.format(class_name, backbone)
            logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'HRNet':
            logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'FastSCNN':
            logging.warning(warning_info.format(class_name, flag,
                                                'CITYSCAPES'))
            flag = 'CITYSCAPES'
    elif flag == 'CITYSCAPES':
        model_name = '{}_{}'.format(class_name, backbone)
        if class_name == 'UNet':
            logging.warning(warning_info.format(class_name, flag, 'COCO'))
            flag = 'COCO'
        if class_name == 'HRNet' and backbone.split('_')[-1] in [
                'W30', 'W32', 'W40', 'W48', 'W60', 'W64'
        ]:
            logging.warning(warning_info.format(backbone, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        if class_name == 'DeepLabv3p' and backbone in [
                'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
                'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
        ]:
            model_name = '{}_{}'.format(class_name, backbone)
            logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
    elif flag == 'IMAGENET':
        if class_name == 'UNet':
            logging.warning(warning_info.format(class_name, flag, 'COCO'))
            flag = 'COCO'
        elif class_name == 'FastSCNN':
            logging.warning(warning_info.format(class_name, flag,
                                                'CITYSCAPES'))
            flag = 'CITYSCAPES'
    elif flag == 'BAIDU10W':
        if class_name not in ['ResNet50_vd']:
            raise Exception(
                "Only the classifier ResNet50_vd supports BAIDU10W pretrained weights"
            )

    if flag == 'IMAGENET':
        new_save_dir = save_dir
        if hasattr(paddlex, 'pretrain_dir'):
            new_save_dir = paddlex.pretrain_dir
        if backbone.startswith('Xception'):
            backbone = 'Seg{}'.format(backbone)
        elif backbone == 'MobileNetV2':
            backbone = 'MobileNetV2_x1.0'
        elif backbone == 'MobileNetV3_small_ssld':
            backbone = 'MobileNetV3_small_x1_0_ssld'
        elif backbone == 'MobileNetV3_large_ssld':
            backbone = 'MobileNetV3_large_x1_0_ssld'
        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
            if backbone == 'ResNet50':
                backbone = 'DetResNet50'
        assert backbone in image_pretrain, "There are no ImageNet pretrained weights for {}, you may try COCO.".format(
            backbone)

        if getattr(paddlex, 'gui_mode', False):
            url = image_pretrain[backbone]
            fname = osp.split(url)[-1].split('.')[0]
            paddlex.utils.download_and_decompress(url, path=new_save_dir)
            return osp.join(new_save_dir, fname)

        import paddlehub as hub
        try:
            logging.info(
                "Connecting PaddleHub server to get pretrain weights...")
            hub.download(backbone, save_path=new_save_dir)
        except Exception as e:
            logging.error(
                "Couldn't download pretrain weight, you can download it manualy from {} (decompress the file if it is a compressed file), and set pretrain weights by your self"
                .format(image_pretrain[backbone]),
                exit=False)
            if isinstance(e, hub.ResourceNotFoundError):
                raise Exception(
                    "Resource for backbone {} not found".format(backbone))
            elif isinstance(e, hub.ServerConnectionError):
                raise Exception(
                    "Cannot get reource for backbone {}, please check your internet connection"
                    .format(backbone))
            else:
                raise Exception(
                    "Unexpected error, please make sure paddlehub >= 1.6.2")
        return osp.join(new_save_dir, backbone)
    elif flag in ['COCO', 'CITYSCAPES']:
        new_save_dir = save_dir
        if hasattr(paddlex, 'pretrain_dir'):
            new_save_dir = paddlex.pretrain_dir
        if class_name in [
                'YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p', 'PPYOLO'
        ]:
            backbone = '{}_{}'.format(class_name, backbone)
        backbone = "{}_{}".format(backbone, flag)
        if flag == 'COCO':
            url = coco_pretrain[backbone]
        elif flag == 'CITYSCAPES':
            url = cityscapes_pretrain[backbone]
        fname = osp.split(url)[-1].split('.')[0]

        if getattr(paddlex, 'gui_mode', False):
            paddlex.utils.download_and_decompress(url, path=new_save_dir)
            return osp.join(new_save_dir, fname)

        import paddlehub as hub
        try:
            logging.info(
                "Connecting PaddleHub server to get pretrain weights...")
            hub.download(backbone, save_path=new_save_dir)
        except Exception as e:
            logging.error(
                "Couldn't download pretrain weight, you can download it manualy from {} (decompress the file if it is a compressed file), and set pretrain weights by your self"
                .format(url),
                exit=False)
            if isinstance(e, hub.ResourceNotFoundError):
                raise Exception(
                    "Resource for backbone {} not found".format(backbone))
            elif isinstance(e, hub.ServerConnectionError):
                raise Exception(
                    "Cannot get reource for backbone {}, please check your internet connection"
                    .format(backbone))
            else:
                raise Exception(
                    "Unexpected error, please make sure paddlehub >= 1.6.2")
        return osp.join(new_save_dir, backbone)
    elif flag == 'BAIDU10W':
        new_save_dir = save_dir
        if hasattr(paddlex, 'pretrain_dir'):
            new_save_dir = paddlex.pretrain_dir
        backbone = backbone + '_BAIDU10W'
        url = baidu10w_pretrain[backbone]
        fname = osp.split(url)[-1].split('.')[0]

        if getattr(paddlex, 'gui_mode', False):
            paddlex.utils.download_and_decompress(url, path=new_save_dir)
            return osp.join(new_save_dir, fname)

        import paddlehub as hub
        try:
            logging.info(
                "Connecting PaddleHub server to get pretrain weights...")
            hub.download(backbone, save_path=new_save_dir)
        except Exception as e:
            logging.error(
                "Couldn't download pretrain weight, you can download it manualy from {} (decompress the file if it is a compressed file), and set pretrain weights by your self"
                .format(url),
                exit=False)
            if isinstance(e, hub.ResourceNotFoundError):
                raise Exception(
                    "Resource for backbone {} not found".format(backbone))
            elif isinstance(e, hub.ServerConnectionError):
                raise Exception(
                    "Cannot get reource for backbone {}, please check your internet connection"
                    .format(backbone))
            else:
                raise Exception(
                    "Unexpected error, please make sure paddlehub >= 1.6.2")
        return osp.join(new_save_dir, backbone)
    else:
        logging.error(
            "Path of retrain weights '{}' is not exists!".format(flag))