class DetBenchTrain(nn.Module):
    def __init__(self, model, config):
        super(DetBenchTrain, self).__init__()
        self.config = config
        self.model = model
        self.anchors = Anchors(
            config.min_level, config.max_level,
            config.num_scales, config.aspect_ratios,
            config.anchor_scale, config.image_size)
        self.anchor_labeler = AnchorLabeler(self.anchors, config.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(self.config)

    def forward(self, x, target):
        class_out, box_out = self.model(x)
        cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
            x.shape[0], target['bbox'], target['cls'])
        loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
        output = dict(loss=loss, class_loss=class_loss, box_loss=box_loss)
        if not self.training:
            # if eval mode, output detections for evaluation
            class_out, box_out, indices, classes = _post_process(self.config, class_out, box_out)
            if self.config.custom_nms:
                min_score = self.config.nms_min_score
                iou_threshold = self.config.nms_max_iou
                detections = _batch_detection_cluster_nms(
                    x.shape[0], class_out, box_out, self.anchors.boxes, indices,
                    target['img_scale'], target['img_size'], min_score, iou_threshold)
            else:
                detections = _batch_detection(
                    x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes,
                    target['img_scale'], target['img_size'])
            output['detections'] = detections
        return output
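A minimal usage sketch for this bench, assuming the get_efficientdet_config and EfficientDet helpers shown in the later examples; the model name, input size, and dummy targets are illustrative assumptions, not part of the original example.

# Hypothetical usage sketch; model name, input size, and dummy data are assumptions.
import torch

config = get_efficientdet_config('tf_efficientdet_d0')  # assumed config helper
model = EfficientDet(config)
bench = DetBenchTrain(model, config)
bench.train()

images = torch.randn(2, 3, 512, 512)  # d0-sized dummy batch, assumed
target = {
    'bbox': [torch.tensor([[10., 10., 50., 50.]]) for _ in range(2)],  # yxyx per image
    'cls': [torch.tensor([1.]) for _ in range(2)],
}
output = bench(images, target)  # training mode: loss dict only, no detections
output['loss'].backward()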
Example #4
def create_datasets_and_loaders(args, model_config):
    input_config = resolve_input_config(args, model_config=model_config)

    dataset_train, dataset_eval = create_dataset(args.dataset, args.root)

    # setup labeler in loader/collate_fn if not enabled in the model bench
    labeler = None
    if not args.bench_labeler:
        labeler = AnchorLabeler(Anchors.from_config(model_config),
                                model_config.num_classes,
                                match_threshold=0.5)

    loader_train = create_loader(
        dataset_train,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        # color_jitter=args.color_jitter,
        # auto_augment=args.aa,
        interpolation=args.train_interpolation
        or input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
    )

    if args.val_skip > 1:
        dataset_eval = SkipSubset(dataset_eval, args.val_skip)
    loader_eval = create_loader(
        dataset_eval,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
    )

    evaluator = create_evaluator(args.dataset,
                                 loader_eval.dataset,
                                 distributed=args.distributed,
                                 pred_yxyx=False)

    return loader_train, loader_eval, evaluator
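Since create_datasets_and_loaders only reads attributes off args, a simple namespace is enough for experimentation; every field below is inferred from the attribute accesses in the function body, not from a documented CLI, and the paths are placeholders.

# Hypothetical invocation; all fields and paths are assumptions.
from types import SimpleNamespace

args = SimpleNamespace(
    dataset='coco', root='/data/coco',
    batch_size=16, workers=4, prefetcher=True,
    distributed=False, pin_mem=True, bench_labeler=False,
    reprob=0.0, remode='pixel', recount=1,
    train_interpolation='',  # falsy, so input_config['interpolation'] wins via `or`
    val_skip=0,
)
model_config = get_efficientdet_config('tf_efficientdet_d0')  # assumed helper
loader_train, loader_eval, evaluator = create_datasets_and_loaders(args, model_config)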
Example #5
    def build_training_data_loader(self):

        if self.context.get_hparam("fake_data"):
            dataset_train = FakeBackend()
            self.dataset_eval = dataset_train

        else:
            dataset_train, self.dataset_eval = create_dataset(
                self.args.dataset, self.args.root)

        self.labeler = None
        if not self.args.bench_labeler:
            self.labeler = AnchorLabeler(
                Anchors.from_config(self.model_config),
                self.model_config.num_classes,
                match_threshold=0.5)

        loader_train = self._create_loader(
            dataset_train,
            input_size=self.input_config['input_size'],
            batch_size=self.context.get_per_slot_batch_size(),
            is_training=True,
            use_prefetcher=self.args.prefetcher,
            re_prob=self.args.reprob,
            re_mode=self.args.remode,
            re_count=self.args.recount,
            # color_jitter=self.args.color_jitter,
            # auto_augment=self.args.aa,
            interpolation=self.args.train_interpolation
            or self.input_config['interpolation'],
            fill_color=self.input_config['fill_color'],
            mean=self.input_config['mean'],
            std=self.input_config['std'],
            num_workers=1,  #self.args.workers,
            distributed=self.args.distributed,
            pin_mem=self.args.pin_mem,
            anchor_labeler=self.labeler,
        )

        fake_data = self.context.get_hparam("fake_data")
        if not fake_data:
            max_label = loader_train.dataset.parser.max_label
            if self.model_config.num_classes < max_label:
                logging.error(
                    f'Model has {self.model_config.num_classes} classes, fewer than '
                    f'the dataset max label {max_label}.')
                sys.exit(1)
            if self.model_config.num_classes > max_label:
                logging.warning(
                    f'Model has {self.model_config.num_classes} classes, more than '
                    f'the dataset max label {max_label}.')

        self.data_length = len(loader_train)

        return loader_train
Example #6
    def __init__(self, root, ann_file, config, transform=None):
        super(CocoDetection, self).__init__()
        if isinstance(root, torch._six.string_classes):
            root = os.path.expanduser(root)
        self.root = root
        self.transform = transform
        self.yxyx = True  # expected for TF model, most PT are xyxy
        self.include_masks = False
        self.include_bboxes_ignore = False
        self.has_annotations = 'image_info' not in ann_file
        self.coco = None
        self.cat_ids = []
        self.cat_to_label = dict()
        self.img_ids = []
        self.img_ids_invalid = []
        self.img_infos = []
        self._load_annotations(ann_file)
        self.anchors = Anchors(config.min_level, config.max_level,
                               config.num_scales, config.aspect_ratios,
                               config.anchor_scale, config.image_size)
        self.anchor_labeler = AnchorLabeler(self.anchors,
                                            config.num_classes,
                                            match_threshold=0.5)
Example #7
class CocoDetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``

    """
    def __init__(self, root, ann_file, config, transform=None):
        super(CocoDetection, self).__init__()
        if isinstance(root, torch._six.string_classes):
            root = os.path.expanduser(root)
        self.root = root
        self.transform = transform
        self.yxyx = True  # expected for TF model, most PT are xyxy
        self.include_masks = False
        self.include_bboxes_ignore = False
        self.has_annotations = 'image_info' not in ann_file
        self.coco = None
        self.cat_ids = []
        self.cat_to_label = dict()
        self.img_ids = []
        self.img_ids_invalid = []
        self.img_infos = []
        self._load_annotations(ann_file)
        self.anchors = Anchors(config.min_level, config.max_level,
                               config.num_scales, config.aspect_ratios,
                               config.anchor_scale, config.image_size)
        self.anchor_labeler = AnchorLabeler(self.anchors,
                                            config.num_classes,
                                            match_threshold=0.5)

    def _load_annotations(self, ann_file):
        assert self.coco is None
        self.coco = COCO(ann_file)
        self.cat_ids = self.coco.getCatIds()
        img_ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
        for img_id in sorted(self.coco.imgs.keys()):
            info = self.coco.loadImgs([img_id])[0]
            valid_annotation = not self.has_annotations or img_id in img_ids_with_ann
            if valid_annotation and min(info['width'], info['height']) >= 32:
                self.img_ids.append(img_id)
                self.img_infos.append(info)
            else:
                self.img_ids_invalid.append(img_id)

    def _parse_img_ann(self, img_id, img_info):
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        ann_info = self.coco.loadAnns(ann_ids)
        bboxes = []
        bboxes_ignore = []
        cls = []

        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if self.include_masks and ann['area'] <= 0:
                continue
            if w < 1 or h < 1:
                continue

            # Subtract 1 from coords or not? TF doesn't appear to, so leave it out for now.
            if self.yxyx:
                #bbox = [y1, x1, y1 + h - 1, x1 + w - 1]
                bbox = [y1, x1, y1 + h, x1 + w]
            else:
                #bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
                bbox = [x1, y1, x1 + w, y1 + h]

            if ann.get('iscrowd', False):
                if self.include_bboxes_ignore:
                    bboxes_ignore.append(bbox)
            else:
                bboxes.append(bbox)
                cls.append(self.cat_to_label[ann['category_id']]
                           if self.cat_to_label else ann['category_id'])

        if bboxes:
            bboxes = np.array(bboxes, dtype=np.float32)
            cls = np.array(cls, dtype=np.int64)
        else:
            bboxes = np.zeros((0, 4), dtype=np.float32)
            cls = np.array([], dtype=np.int64)

        if self.include_bboxes_ignore:
            if bboxes_ignore:
                bboxes_ignore = np.array(bboxes_ignore, dtype=np.float32)
            else:
                bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(img_id=img_id,
                   bbox=bboxes,
                   cls=cls,
                   img_size=(img_info['width'], img_info['height']))

        if self.include_bboxes_ignore:
            ann['bbox_ignore'] = bboxes_ignore

        return ann

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, annotations (target)).
        """
        img_id = self.img_ids[index]
        img_info = self.img_infos[index]
        if self.has_annotations:
            ann = self._parse_img_ann(img_id, img_info)
        else:
            ann = dict(img_id=img_id,
                       img_size=(img_info['width'], img_info['height']))

        path = img_info['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            img, ann = self.transform(img, ann)

        cls_targets, box_targets, num_positives = self.anchor_labeler.label_anchors(
            ann['bbox'], ann['cls'])
        ann.pop('bbox')
        ann.pop('cls')
        ann['num_positives'] = num_positives
        ann.update(cls_targets)
        ann.update(box_targets)

        return img, ann

    def __len__(self):
        return len(self.img_ids)
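A short usage sketch with placeholder COCO paths; note that __getitem__ pops the raw 'bbox' and 'cls' keys and returns per-level anchor targets instead, and that in practice a transform converting boxes to tensors is supplied rather than the None default.

# Hypothetical usage; paths and model name are placeholders.
config = get_efficientdet_config('tf_efficientdet_d0')  # assumed config helper
dataset = CocoDetection(
    root='/data/coco/train2017',
    ann_file='/data/coco/annotations/instances_train2017.json',
    config=config)
img, ann = dataset[0]  # ann carries anchor targets plus 'num_positives'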
Example #8
def create_datasets_and_loaders(
    args,
    model_config,
    transform_train_fn=None,
    transform_eval_fn=None,
    collate_fn=None,
):
    """ Setup datasets, transforms, loaders, evaluator.

    Args:
        args: Command line args / config for training
        model_config: Model specific configuration dict / struct
        transform_train_fn: Override default image + annotation transforms (see note in loaders.py)
        transform_eval_fn: Override default image + annotation transforms (see note in loaders.py)
        collate_fn: Override default fast collate function

    Returns:
        Train loader, validation loader, evaluator
    """
    input_config = resolve_input_config(args, model_config=model_config)

    dataset_train, dataset_eval = create_dataset(args.dataset, args.root)

    # setup labeler in loader/collate_fn if not enabled in the model bench
    labeler = None
    if not args.bench_labeler:
        labeler = AnchorLabeler(Anchors.from_config(model_config),
                                model_config.num_classes,
                                match_threshold=0.5)

    loader_train = create_loader(
        dataset_train,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        # color_jitter=args.color_jitter,
        # auto_augment=args.aa,
        interpolation=args.train_interpolation
        or input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
        transform_fn=transform_train_fn,
        collate_fn=collate_fn,
    )

    if args.val_skip > 1:
        dataset_eval = SkipSubset(dataset_eval, args.val_skip)
    loader_eval = create_loader(
        dataset_eval,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
        transform_fn=transform_eval_fn,
        collate_fn=collate_fn,
    )

    evaluator = create_evaluator(args.dataset,
                                 loader_eval.dataset,
                                 distributed=args.distributed,
                                 pred_yxyx=False)

    return loader_train, loader_eval, evaluator
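The transform_train_fn / transform_eval_fn hooks appear to follow the same (img, ann) -> (img, ann) contract used by CocoDetection.__getitem__ in Example #7; a minimal pass-through sketch (the override name and body are illustrative only, reusing the args / model_config setup from the earlier sketch):

# Hypothetical transform override; the (img, ann) contract is inferred, not documented here.
def my_transform_train_fn(img, ann):
    # apply any extra augmentation here, then return the same-shaped pair
    return img, ann

loader_train, loader_eval, evaluator = create_datasets_and_loaders(
    args, model_config, transform_train_fn=my_transform_train_fn)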
class EfficientDetTrain(pl.LightningModule):
    def __init__(self, model_name, num_classes, create_labeler=True):
        super().__init__()
        self.save_hyperparameters()
        config = get_efficientdet_config(model_name)

        self.model = EfficientDet(config)
        self.model.reset_head(num_classes=num_classes)
        self.num_levels = self.model.config.num_levels
        self.num_classes = self.model.config.num_classes
        self.anchors = Anchors.from_config(self.model.config)
        self.anchor_labeler = None
        if create_labeler:
            self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(self.model.config)

    def forward(self, x):
        class_out, box_out = self.model(x)
        return class_out, box_out

    def training_step(self, batch, batch_idx):
        x, targets, idx = batch
        x = torch.stack(x, dim=0)
        class_out, box_out = self.forward(x)

        bbox = [tar['bbox'].float() for tar in targets]
        clses = [tar['cls'].float() for tar in targets]
        target = {'bbox': bbox, 'cls': clses}
        if self.anchor_labeler is None:
            # pre-computed anchor labels ride along on the per-sample target dicts
            # when no labeler is present in the bench; stack them per level
            assert 'label_num_positives' in targets[0]
            cls_targets = [torch.stack([t[f'label_cls_{l}'] for t in targets])
                           for l in range(self.num_levels)]
            box_targets = [torch.stack([t[f'label_bbox_{l}'] for t in targets])
                           for l in range(self.num_levels)]
            num_positives = torch.stack([t['label_num_positives'] for t in targets])
        else:
            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
                target['bbox'], target['cls'])

        loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
        return loss

    def validation_step(self, batch, batch_idx):
        x, targets, idx = batch
        x = torch.stack(x, dim=0)
        class_out, box_out = self.forward(x)

        bbox = [tar['bbox'].float() for tar in targets]
        clses = [tar['cls'].float() for tar in targets]
        target = {'bbox': bbox, 'cls': clses}
        if self.anchor_labeler is None:
            # pre-computed anchor labels ride along on the per-sample target dicts
            # when no labeler is present in the bench; stack them per level
            assert 'label_num_positives' in targets[0]
            cls_targets = [torch.stack([t[f'label_cls_{l}'] for t in targets])
                           for l in range(self.num_levels)]
            box_targets = [torch.stack([t[f'label_bbox_{l}'] for t in targets])
                           for l in range(self.num_levels)]
            num_positives = torch.stack([t['label_num_positives'] for t in targets])
        else:
            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
                target['bbox'], target['cls'])

        loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
        return {'val_loss': loss, "class_loss": class_loss, 'box_loss': box_loss}

    def configure_optimizers(self):
        # optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-4, momentum=0.9, weight_decay=4e-5)
        return optimizer
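To round out the Lightning example, a hedged driver sketch; the synthetic dataset, collate function, and (image, target, idx) batch layout are assumptions chosen to match the unpacking in training_step, not part of the original module.

# Hypothetical training driver; dataset and batch layout are assumptions.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

class ToyDetDataset(Dataset):
    # synthetic stand-in yielding (image, target, idx) triples as training_step expects
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        img = torch.randn(3, 512, 512)
        target = {'bbox': torch.tensor([[10., 10., 50., 50.]]),  # yxyx
                  'cls': torch.tensor([1.])}
        return img, target, idx

def tuple_collate(batch):
    # keep per-sample items separate so training_step can torch.stack the images
    return tuple(zip(*batch))

train_loader = DataLoader(ToyDetDataset(), batch_size=4, collate_fn=tuple_collate)
val_loader = DataLoader(ToyDetDataset(), batch_size=4, collate_fn=tuple_collate)

module = EfficientDetTrain('tf_efficientdet_d0', num_classes=2)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(module, train_loader, val_loader)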