示例#1
0
    def init_param(self, model_config):
        """Populate settings from ``model_config`` and build helper objects.

        Scalar settings are copied onto the instance first: ``din`` is stored
        as ``in_channels``, while ``post_nms_topN``, ``pre_nms_topN``,
        ``nms_thresh`` and ``use_focal_loss`` keep their names. After that
        the anchor generator, the ``Instance`` helper and the sampler are
        built from their respective sub-configs.

        Args:
            model_config: mapping providing every key listed above plus
                ``anchor_generator_config``, ``instance`` and
                ``sampler_config``; keys are read with ``[]`` indexing, so a
                missing key raises.
        """
        # (attribute name, config key) pairs, copied in this order.
        copied_settings = (
            ('in_channels', 'din'),
            ('post_nms_topN', 'post_nms_topN'),
            ('pre_nms_topN', 'pre_nms_topN'),
            ('nms_thresh', 'nms_thresh'),
            ('use_focal_loss', 'use_focal_loss'),
        )
        for attr_name, config_key in copied_settings:
            setattr(self, attr_name, model_config[config_key])

        # num_anchors is taken from whatever generator the config builds.
        generator = anchor_generators.build(
            model_config['anchor_generator_config'])
        self.anchor_generator = generator
        self.num_anchors = generator.num_anchors

        self.instance = Instance(model_config['instance'])
        self.sampler = samplers.build(model_config['sampler_config'])
示例#2
0
def generate_anchors(feature_map_list=None, input_size=None,
                     anchor_generator_config=None):
    """Build an anchor generator and return the anchors it generates.

    Every argument is optional. When an argument is omitted, the value that
    used to be hard-coded here is used, so ``generate_anchors()`` behaves
    exactly as before.

    Args:
        feature_map_list: feature-map sizes forwarded to
            ``anchor_generator.generate``. Defaults to ``[(24, 80)]``.
        input_size: input size forwarded to ``anchor_generator.generate``.
            Defaults to ``[384, 1280]`` (presumably height, width — confirm
            against the generator implementation).
        anchor_generator_config: config dict forwarded to
            ``anchor_generators.build``. Defaults to a ``"default"``-type
            generator with stride 16, aspect ratios ``[0.5, 0.8, 1]``,
            base size 16 and scales ``[2, 4, 8, 16]``.

    Returns:
        Whatever ``anchor_generator.generate`` returns for these inputs.
    """
    # Defaults are created per call, so callers never share mutable objects.
    if anchor_generator_config is None:
        anchor_generator_config = {
            "type": "default",
            "anchor_offset": [0, 0],
            "anchor_stride": [16, 16],
            "aspect_ratios": [0.5, 0.8, 1],
            "base_anchor_size": 16,
            "scales": [2, 4, 8, 16]
        }
    # 24 x 80 equals the default input size 384 x 1280 divided by the
    # default stride of 16.
    if feature_map_list is None:
        feature_map_list = [(24, 80)]
    if input_size is None:
        input_size = [384, 1280]

    anchor_generator = anchor_generators.build(anchor_generator_config)
    return anchor_generator.generate(feature_map_list, input_size)
示例#3
0
    def init_param(self, model_config):
        """Copy settings from ``model_config`` and build anchor/target helpers.

        Besides the scalar settings (``din`` stored as ``in_channels``,
        ``post_nms_topN``, ``pre_nms_topN``, ``nms_thresh``,
        ``use_focal_loss``), this builds the anchor generator and derives two
        per-anchor output sizes from its ``num_anchors``. It also builds the
        target generator.

        Args:
            model_config: mapping providing the keys above plus
                ``anchor_generator_config`` and ``target_generator_config``.
        """
        cfg = model_config
        self.in_channels = cfg['din']
        self.post_nms_topN = cfg['post_nms_topN']
        self.pre_nms_topN = cfg['pre_nms_topN']
        self.nms_thresh = cfg['nms_thresh']
        self.use_focal_loss = cfg['use_focal_loss']

        generator = anchor_generators.build(cfg['anchor_generator_config'])
        self.anchor_generator = generator
        anchors_per_loc = generator.num_anchors
        self.num_anchors = anchors_per_loc

        # nc_* presumably = number of output channels: 4 box values and
        # 2 score values for every anchor.
        self.nc_bbox_out = 4 * anchors_per_loc
        self.nc_score_out = anchors_per_loc * 2

        self.target_generators = TargetGenerator(
            cfg['target_generator_config'])
示例#4
0
    def init_param(self, model_config):
        """Read head settings from ``model_config`` and precompute anchors.

        Args:
            model_config: mapping that must provide ``classes``,
                ``feature_extractor_config``, ``target_generator_config``,
                ``anchor_generator_config``, ``input_size`` and
                ``use_focal_loss``. It may also provide ``in_channels``
                (default 128) and ``num_regress`` (default 4).
        """
        cfg = model_config

        # One extra slot is reserved for the background class.
        self.num_classes = 1 + len(cfg['classes'])
        self.in_channels = cfg.get('in_channels', 128)
        self.num_regress = cfg.get('num_regress', 4)
        self.feature_extractor_config = cfg['feature_extractor_config']

        self.target_generators = TargetGenerator(
            cfg['target_generator_config'])
        generator = anchor_generators.build(cfg['anchor_generator_config'])
        self.anchor_generator = generator
        self.num_anchors = generator.num_anchors

        # Anchors are generated once, up front, from the configured input size.
        size_tensor = torch.tensor(cfg['input_size']).float()
        self.anchors = generator.generate(size_tensor)

        self.use_focal_loss = cfg['use_focal_loss']
示例#5
0
    def __init__(self, dataset_config, transform=None, training=True):
        """Set up the sample list and the default (anchor) boxes.

        The sample source is chosen from ``dataset_config`` in this order:

        - ``img_dir``: sample names loaded from that image directory, sorted;
        - ``demo_file``: a single file;
        - otherwise: a list file built from ``root_path`` / ``label_path`` /
          ``dataset_file``, with images resolved under ``data_path``.

        Args:
            dataset_config: dict-like dataset configuration. Besides the keys
                above it must provide ``classes``,
                ``anchor_generator_config`` and ``input_shape``.
            transform: optional transform(s), stored as ``self.transforms``.
            training: forwarded to the parent class constructor.
        """
        super().__init__(training)
        self.transforms = transform
        # 'bg' is prepended, so class index 0 is the background.
        self.classes = ['bg'] + dataset_config['classes']

        if dataset_config.get('img_dir') is not None:
            self.image_dir = dataset_config['img_dir']
            # Directory mode: sorted for a deterministic sample order.
            self.sample_names = sorted(
                self.load_sample_names_from_image_dir(self.image_dir))
            self.imgs = self.sample_names
        elif dataset_config.get('demo_file') is not None:
            # Single-file (demo) mode.
            self.sample_names = sorted([dataset_config['demo_file']])
            self.imgs = self.sample_names
        else:
            # List-file mode (originally labelled "val dataset").
            self.root_path = dataset_config['root_path']
            self.data_path = os.path.join(self.root_path,
                                          dataset_config['data_path'])
            self.label_path = os.path.join(self.root_path,
                                           dataset_config['label_path'])

            # presumably the split file lists one sample id per line — confirm
            # against make_label_list.
            self.sample_names = self.make_label_list(
                os.path.join(self.label_path, dataset_config['dataset_file']))
            self.imgs = self.make_image_list()

        # NOTE(review): presumably an upper bound on boxes kept per sample
        # (e.g. for padding targets) — confirm where it is used.
        self.max_num_boxes = 100
        self.anchor_generator = anchor_generators.build(
            dataset_config['anchor_generator_config'])
        # normalize=True presumably yields coordinates relative to input_shape
        # (in [0, 1]) — verify against the generator implementation.
        default_boxes = self.anchor_generator.generate(
            dataset_config['input_shape'], normalize=True)

        # Convert corner-format (xyxy) boxes to xywh and keep element [0]
        # (presumably dropping a leading batch dimension — confirm).
        self.default_boxes = geometry_utils.torch_xyxy_to_xywh(
            default_boxes)[0]