Example No. 1
def create_datasets_and_loaders(args, model_config):
    input_config = resolve_input_config(args, model_config=model_config)

    dataset_train, dataset_eval = create_dataset(args.dataset, args.root)

    # setup labeler in loader/collate_fn if not enabled in the model bench
    labeler = None
    if not args.bench_labeler:
        labeler = AnchorLabeler(Anchors.from_config(model_config),
                                model_config.num_classes,
                                match_threshold=0.5)

    loader_train = create_loader(
        dataset_train,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        # color_jitter=args.color_jitter,
        # auto_augment=args.aa,
        interpolation=args.train_interpolation
        or input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
    )

    if args.val_skip > 1:
        dataset_eval = SkipSubset(dataset_eval, args.val_skip)
    loader_eval = create_loader(
        dataset_eval,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
    )

    evaluator = create_evaluator(args.dataset,
                                 loader_eval.dataset,
                                 distributed=args.distributed,
                                 pred_yxyx=False)

    return loader_train, loader_eval, evaluator
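
A minimal sketch of driving this helper. Every attribute on args below is an assumption read off the function body (not a confirmed CLI), and model_config would come from e.g. effdet's get_efficientdet_config:

from types import SimpleNamespace
from effdet import get_efficientdet_config

# Hypothetical args namespace; field names mirror the reads above.
args = SimpleNamespace(
    dataset='coco', root='/data/coco', batch_size=16, workers=4,
    prefetcher=True, reprob=0.0, remode='pixel', recount=1,
    train_interpolation=None, distributed=False, pin_mem=True,
    bench_labeler=False, val_skip=0)
model_config = get_efficientdet_config('tf_efficientdet_d0')
loader_train, loader_eval, evaluator = create_datasets_and_loaders(args, model_config)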
Example No. 2
    def build_training_data_loader(self):
        if self.context.get_hparam("fake_data"):
            dataset_train = FakeBackend()
            self.dataset_eval = dataset_train

        else:
            dataset_train, self.dataset_eval = create_dataset(
                self.args.dataset, self.args.root)

        self.labeler = None
        if not self.args.bench_labeler:
            self.labeler = AnchorLabeler(
                Anchors.from_config(self.model_config),
                self.model_config.num_classes,
                match_threshold=0.5)

        loader_train = self._create_loader(
            dataset_train,
            input_size=self.input_config['input_size'],
            batch_size=self.context.get_per_slot_batch_size(),
            is_training=True,
            use_prefetcher=self.args.prefetcher,
            re_prob=self.args.reprob,
            re_mode=self.args.remode,
            re_count=self.args.recount,
            # color_jitter=self.args.color_jitter,
            # auto_augment=self.args.aa,
            interpolation=self.args.train_interpolation
            or self.input_config['interpolation'],
            fill_color=self.input_config['fill_color'],
            mean=self.input_config['mean'],
            std=self.input_config['std'],
            num_workers=1,  # self.args.workers
            distributed=self.args.distributed,
            pin_mem=self.args.pin_mem,
            anchor_labeler=self.labeler,
        )

        if not self.context.get_hparam("fake_data"):
            max_label = loader_train.dataset.parser.max_label
            if self.model_config.num_classes < max_label:
                logging.error(
                    f'Model has fewer classes ({self.model_config.num_classes}) '
                    f'than the dataset ({max_label}).')
                sys.exit(1)
            if self.model_config.num_classes > max_label:
                logging.warning(
                    f'Model has more classes ({self.model_config.num_classes}) '
                    f'than the dataset ({max_label}).')

        self.data_length = len(loader_train)

        return loader_train
Example No. 3
    def __init__(self, model, config):
        super(DetBenchTrain, self).__init__()
        self.config = config
        self.model = model
        self.anchors = Anchors(
            config.min_level, config.max_level,
            config.num_scales, config.aspect_ratios,
            config.anchor_scale, config.image_size)
        self.anchor_labeler = AnchorLabeler(self.anchors, config.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(self.config)
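
A bench like this labels anchors inside forward, so the loader needs no AnchorLabeler of its own. A hedged usage sketch, assuming the target-dict convention of current effdet (older variants passed gt_boxes and gt_labels as separate arguments):

# Sketch only: tensors are illustrative.
bench = DetBenchTrain(model, config)
target = {'bbox': gt_boxes, 'cls': gt_classes}  # yxyx boxes and class ids
output = bench(images, target)  # typically a dict with 'loss', 'class_loss', 'box_loss'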
Example No. 4
    def __init__(self, model_name, num_classes, create_labeler=True):
        super().__init__()
        self.save_hyperparameters()
        config = get_efficientdet_config(model_name)

        self.model = EfficientDet(config)
        self.model.reset_head(num_classes=num_classes)
        self.num_levels = self.model.config.num_levels
        self.num_classes = self.model.config.num_classes
        self.anchors = Anchors.from_config(self.model.config)
        self.anchor_labeler = None
        if create_labeler:
            self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(self.model.config)
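
All of these constructors configure the labeler identically: at train time it matches ground-truth yxyx boxes and class ids to anchors, producing per-level targets for DetectionLoss. A hedged single-image sketch, assuming effdet's label_anchors signature and made-up tensors:

import torch
from effdet import get_efficientdet_config
from effdet.anchors import Anchors, AnchorLabeler

config = get_efficientdet_config('tf_efficientdet_d0')
labeler = AnchorLabeler(Anchors.from_config(config), config.num_classes,
                        match_threshold=0.5)
gt_boxes = torch.tensor([[32., 48., 256., 320.]])  # one yxyx box, illustrative
gt_classes = torch.tensor([3.])
# Per-level class/box targets plus the positive count used to normalize the loss.
cls_targets, box_targets, num_positives = labeler.label_anchors(gt_boxes, gt_classes)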
Example No. 5
def post_processing(img, outputs):
    # The first five arrays are per-level class logits, the last five are
    # per-level box regressions (5 FPN levels).
    cls_outputs = [torch.from_numpy(o) for o in outputs[:5]]
    box_outputs = [torch.from_numpy(o) for o in outputs[5:10]]
    width, height = img.size
    target_size = (512, 512)
    img_scale_y = target_size[0] / height
    img_scale_x = target_size[1] / width
    img_size = torch.tensor([[width, height]], dtype=torch.int32)
    img_scale = torch.tensor([min(img_scale_y, img_scale_x)], dtype=torch.float32)
    class_out, box_out, indices, classes = _post_process(
        cls_outputs, box_outputs, num_levels=5, num_classes=6,
        max_detection_points=5000)
    config = default_detection_model_configs()
    anchors = Anchors.from_config(config)
    detection = _batch_detection(
        1, class_out, box_out, anchors.boxes, indices, classes,
        img_scale, img_size, max_det_per_image=10, soft_nms=True)
    return detection
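
For context, a hedged sketch of calling this helper on exported-model outputs. The ONNX Runtime session, input name, and preprocess helper are assumptions, as is the 6-class, 512x512 setup hard-coded above:

import onnxruntime as ort
from PIL import Image

img = Image.open('sample.jpg')                 # illustrative path
input_tensor = preprocess(img)                 # hypothetical resize/normalize to 512x512
session = ort.InferenceSession('effdet.onnx')  # assumed export
outputs = session.run(None, {'images': input_tensor})  # 5 class maps then 5 box maps
detections = post_processing(img, outputs)     # [1, 10, 6]: bbox(4), score, class per row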
Example No. 6
    def __init__(self, root, ann_file, config, transform=None):
        super(CocoDetection, self).__init__()
        # torch._six was removed in recent PyTorch; a plain str check is equivalent here
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root
        self.transform = transform
        self.yxyx = True  # expected for TF model, most PT are xyxy
        self.include_masks = False
        self.include_bboxes_ignore = False
        self.has_annotations = 'image_info' not in ann_file
        self.coco = None
        self.cat_ids = []
        self.cat_to_label = dict()
        self.img_ids = []
        self.img_ids_invalid = []
        self.img_infos = []
        self._load_annotations(ann_file)
        self.anchors = Anchors(config.min_level, config.max_level,
                               config.num_scales, config.aspect_ratios,
                               config.anchor_scale, config.image_size)
        self.anchor_labeler = AnchorLabeler(self.anchors,
                                            config.num_classes,
                                            match_threshold=0.5)
Example No. 7
def create_datasets_and_loaders(
    args,
    model_config,
    transform_train_fn=None,
    transform_eval_fn=None,
    collate_fn=None,
):
    """ Setup datasets, transforms, loaders, evaluator.

    Args:
        args: Command line args / config for training
        model_config: Model specific configuration dict / struct
        transform_train_fn: Override default image + annotation transforms (see note in loaders.py)
        transform_eval_fn: Override default image + annotation transforms (see note in loaders.py)
        collate_fn: Override default fast collate function

    Returns:
        Train loader, validation loader, evaluator
    """
    input_config = resolve_input_config(args, model_config=model_config)

    dataset_train, dataset_eval = create_dataset(args.dataset, args.root)

    # setup labeler in loader/collate_fn if not enabled in the model bench
    labeler = None
    if not args.bench_labeler:
        labeler = AnchorLabeler(Anchors.from_config(model_config),
                                model_config.num_classes,
                                match_threshold=0.5)

    loader_train = create_loader(
        dataset_train,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        # color_jitter=args.color_jitter,
        # auto_augment=args.aa,
        interpolation=args.train_interpolation
        or input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
        transform_fn=transform_train_fn,
        collate_fn=collate_fn,
    )

    if args.val_skip > 1:
        dataset_eval = SkipSubset(dataset_eval, args.val_skip)
    loader_eval = create_loader(
        dataset_eval,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
        transform_fn=transform_eval_fn,
        collate_fn=collate_fn,
    )

    evaluator = create_evaluator(args.dataset,
                                 loader_eval.dataset,
                                 distributed=args.distributed,
                                 pred_yxyx=False)

    return loader_train, loader_eval, evaluator
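
The extra hooks in this variant make it easy to plug in custom augmentation. A hedged sketch, assuming the (img, annotations) pair convention the docstring points to in loaders.py:

# Hypothetical no-op override; a real transform must modify the image and
# its annotations consistently.
def my_train_transform(img, annotations):
    return img, annotations

loader_train, loader_eval, evaluator = create_datasets_and_loaders(
    args, model_config, transform_train_fn=my_train_transform)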