def __init__(self, model, config):
    """Wrap *model* for training: build anchors, anchor labeler and loss from *config*."""
    super(DetBenchTrain, self).__init__()
    self.model = model
    self.config = config
    # Anchor grid derived from the pyramid / scale settings in the config.
    self.anchors = Anchors(
        config.min_level,
        config.max_level,
        config.num_scales,
        config.aspect_ratios,
        config.anchor_scale,
        config.image_size,
    )
    # Assigns ground-truth boxes to anchors; IoU >= 0.5 counts as a match.
    self.anchor_labeler = AnchorLabeler(
        self.anchors, config.num_classes, match_threshold=0.5)
    self.loss_fn = DetectionLoss(self.config)
def __init__(self, model_name, num_classes, create_labeler=True):
    """Build an EfficientDet model plus anchors / labeler / loss for training.

    Args:
        model_name: Name of a predefined efficientdet configuration.
        num_classes: Number of detection classes; the head is reset to match.
        create_labeler: When True, build an AnchorLabeler here; otherwise the
            dataloader is expected to supply pre-computed anchor labels.
    """
    super().__init__()
    self.save_hyperparameters()
    config = get_efficientdet_config(model_name)
    self.model = EfficientDet(config)
    self.model.reset_head(num_classes=num_classes)
    # Cache config values after the head reset.
    self.num_levels = self.model.config.num_levels
    self.num_classes = self.model.config.num_classes
    self.anchors = Anchors.from_config(self.model.config)
    # IoU >= 0.5 counts as a positive anchor match.
    self.anchor_labeler = (
        AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
        if create_labeler else None
    )
    self.loss_fn = DetectionLoss(self.model.config)
class DetBenchTrain(nn.Module):
    """Training bench: wraps a detection model with anchor labeling and loss.

    In eval mode the forward pass additionally post-processes the raw head
    outputs into detections (optionally via a custom cluster-NMS).
    """

    def __init__(self, model, config):
        super(DetBenchTrain, self).__init__()
        self.model = model
        self.config = config
        self.anchors = Anchors(
            config.min_level, config.max_level, config.num_scales,
            config.aspect_ratios, config.anchor_scale, config.image_size)
        # IoU >= 0.5 counts as a positive anchor match.
        self.anchor_labeler = AnchorLabeler(
            self.anchors, config.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(self.config)

    def forward(self, x, target):
        class_out, box_out = self.model(x)
        batch_size = x.shape[0]
        # Match ground-truth boxes/classes onto the anchor grid.
        cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
            batch_size, target['bbox'], target['cls'])
        loss, class_loss, box_loss = self.loss_fn(
            class_out, box_out, cls_targets, box_targets, num_positives)
        output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
        if not self.training:
            # Eval mode: decode raw head outputs into detections for evaluation.
            class_out, box_out, indices, classes = _post_process(
                self.config, class_out, box_out)
            if self.config.custom_nms:
                detections = _batch_detection_cluster_nms(
                    batch_size, class_out, box_out, self.anchors.boxes, indices,
                    target['img_scale'], target['img_size'],
                    self.config.nms_min_score, self.config.nms_max_iou)
            else:
                detections = _batch_detection(
                    batch_size, class_out, box_out, self.anchors.boxes, indices,
                    classes, target['img_scale'], target['img_size'])
            output['detections'] = detections
        return output
def create_datasets_and_loaders(args, model_config):
    """Create train/eval datasets, their loaders, and an evaluator.

    Args:
        args: Command line args / config for training.
        model_config: Model specific configuration dict / struct.

    Returns:
        Tuple of (train loader, eval loader, evaluator).
    """
    input_config = resolve_input_config(args, model_config=model_config)
    dataset_train, dataset_eval = create_dataset(args.dataset, args.root)

    # Label anchors in the loader/collate_fn unless the model bench does it.
    anchor_labeler = None
    if not args.bench_labeler:
        anchor_labeler = AnchorLabeler(
            Anchors.from_config(model_config), model_config.num_classes, match_threshold=0.5)

    loader_train = create_loader(
        dataset_train,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        # color_jitter=args.color_jitter,
        # auto_augment=args.aa,
        interpolation=args.train_interpolation or input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=anchor_labeler,
    )

    if args.val_skip > 1:
        dataset_eval = SkipSubset(dataset_eval, args.val_skip)
    loader_eval = create_loader(
        dataset_eval,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=anchor_labeler,
    )

    evaluator = create_evaluator(
        args.dataset, loader_eval.dataset, distributed=args.distributed, pred_yxyx=False)
    return loader_train, loader_eval, evaluator
def build_training_data_loader(self):
    """Build the training DataLoader (optionally over fake data).

    Side effects: stores the eval dataset on ``self.dataset_eval``, the
    anchor labeler (or None) on ``self.labeler``, and the loader length on
    ``self.data_length``. Exits the process when the model has fewer classes
    than the dataset requires.

    Returns:
        The training data loader.
    """
    # Hoist the hparam lookup; it is consulted again after loader creation.
    fake_data = self.context.get_hparam("fake_data")
    if fake_data:
        dataset_train = FakeBackend()
        self.dataset_eval = dataset_train
    else:
        dataset_train, self.dataset_eval = create_dataset(
            self.args.dataset, self.args.root)

    # Label anchors in the loader/collate_fn unless the model bench does it.
    self.labeler = None
    if not self.args.bench_labeler:
        self.labeler = AnchorLabeler(
            Anchors.from_config(self.model_config),
            self.model_config.num_classes,
            match_threshold=0.5)

    loader_train = self._create_loader(
        dataset_train,
        input_size=self.input_config['input_size'],
        batch_size=self.context.get_per_slot_batch_size(),
        is_training=True,
        use_prefetcher=self.args.prefetcher,
        re_prob=self.args.reprob,
        re_mode=self.args.remode,
        re_count=self.args.recount,
        # color_jitter=self.args.color_jitter,
        # auto_augment=self.args.aa,
        interpolation=self.args.train_interpolation or self.input_config['interpolation'],
        fill_color=self.input_config['fill_color'],
        mean=self.input_config['mean'],
        std=self.input_config['std'],
        num_workers=1,  # self.args.workers,
        distributed=self.args.distributed,
        pin_mem=self.args.pin_mem,
        anchor_labeler=self.labeler,
    )

    if not fake_data:
        # Sanity-check the model head against the dataset's label space.
        num_classes = self.model_config.num_classes
        max_label = loader_train.dataset.parser.max_label
        if num_classes < max_label:
            logging.error(
                f'Model {num_classes} has fewer classes than dataset {max_label}.')
            sys.exit(1)
        if num_classes > max_label:
            logging.warning(
                f'Model {num_classes} has more classes than dataset {max_label}.')

    self.data_length = len(loader_train)
    return loader_train
def __init__(self, root, ann_file, config, transform=None):
    """COCO-style detection dataset constructor.

    Args:
        root: Root directory containing the images (``~`` is expanded).
        ann_file: Path to the COCO json annotation file.
        config: Model config supplying anchor / class parameters.
        transform: Optional callable applied to (image, annotation) pairs.
    """
    super(CocoDetection, self).__init__()
    # torch._six was removed in modern PyTorch; on Python 3 its
    # string_classes was simply (str,), so isinstance(root, str) is the
    # supported equivalent.
    if isinstance(root, str):
        root = os.path.expanduser(root)
    self.root = root
    self.transform = transform
    self.yxyx = True  # expected for TF model, most PT are xyxy
    self.include_masks = False
    self.include_bboxes_ignore = False
    self.has_annotations = 'image_info' not in ann_file
    self.coco = None
    self.cat_ids = []
    self.cat_to_label = dict()
    self.img_ids = []
    self.img_ids_invalid = []
    self.img_infos = []
    self._load_annotations(ann_file)
    self.anchors = Anchors(
        config.min_level, config.max_level, config.num_scales,
        config.aspect_ratios, config.anchor_scale, config.image_size)
    # IoU >= 0.5 counts as a positive anchor match.
    self.anchor_labeler = AnchorLabeler(self.anchors, config.num_classes, match_threshold=0.5)
class CocoDetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Annotations are parsed into numpy bbox/cls arrays and, in ``__getitem__``,
    pre-labeled onto the anchor grid so downstream collate/bench code does
    not need to run the anchor labeler.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to json annotation file.
        config: Model config supplying anchor / class parameters.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
    """

    def __init__(self, root, ann_file, config, transform=None):
        super(CocoDetection, self).__init__()
        # torch._six was removed in modern PyTorch; isinstance(root, str) is
        # the supported equivalent of the old string_classes check.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root
        self.transform = transform
        self.yxyx = True  # expected for TF model, most PT are xyxy
        self.include_masks = False
        self.include_bboxes_ignore = False
        self.has_annotations = 'image_info' not in ann_file
        self.coco = None
        self.cat_ids = []
        self.cat_to_label = dict()
        self.img_ids = []
        self.img_ids_invalid = []
        self.img_infos = []
        self._load_annotations(ann_file)
        self.anchors = Anchors(
            config.min_level, config.max_level, config.num_scales,
            config.aspect_ratios, config.anchor_scale, config.image_size)
        self.anchor_labeler = AnchorLabeler(self.anchors, config.num_classes, match_threshold=0.5)

    def _load_annotations(self, ann_file):
        """Load COCO annotations, keeping only images at least 32px on the
        short side and (when annotations exist) with at least one annotation."""
        assert self.coco is None
        self.coco = COCO(ann_file)
        self.cat_ids = self.coco.getCatIds()
        img_ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
        for img_id in sorted(self.coco.imgs.keys()):
            info = self.coco.loadImgs([img_id])[0]
            valid_annotation = not self.has_annotations or img_id in img_ids_with_ann
            if valid_annotation and min(info['width'], info['height']) >= 32:
                self.img_ids.append(img_id)
                self.img_infos.append(info)
            else:
                self.img_ids_invalid.append(img_id)

    def _parse_img_ann(self, img_id, img_info):
        """Convert the COCO annotations for one image into bbox/cls arrays."""
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        ann_info = self.coco.loadAnns(ann_ids)
        bboxes = []
        bboxes_ignore = []
        cls = []

        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if self.include_masks and ann['area'] <= 0:
                continue
            if w < 1 or h < 1:
                continue

            # To subtract 1 or not, TF doesn't appear to do this so will keep it out for now.
            if self.yxyx:
                #bbox = [y1, x1, y1 + h - 1, x1 + w - 1]
                bbox = [y1, x1, y1 + h, x1 + w]
            else:
                #bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
                bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                if self.include_bboxes_ignore:
                    bboxes_ignore.append(bbox)
            else:
                bboxes.append(bbox)
                cls.append(self.cat_to_label[ann['category_id']] if self.cat_to_label else ann['category_id'])

        if bboxes:
            bboxes = np.array(bboxes, dtype=np.float32)
            cls = np.array(cls, dtype=np.int64)
        else:
            bboxes = np.zeros((0, 4), dtype=np.float32)
            cls = np.array([], dtype=np.int64)

        if self.include_bboxes_ignore:
            if bboxes_ignore:
                bboxes_ignore = np.array(bboxes_ignore, dtype=np.float32)
            else:
                bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(img_id=img_id, bbox=bboxes, cls=cls, img_size=(img_info['width'], img_info['height']))
        if self.include_bboxes_ignore:
            ann['bbox_ignore'] = bboxes_ignore
        return ann

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, annotations (target)).
        """
        img_id = self.img_ids[index]
        img_info = self.img_infos[index]
        if self.has_annotations:
            ann = self._parse_img_ann(img_id, img_info)
        else:
            ann = dict(img_id=img_id, img_size=(img_info['width'], img_info['height']))

        path = img_info['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            img, ann = self.transform(img, ann)

        # Pre-compute per-level anchor labels here so the collate_fn / bench
        # does not have to run the labeler. bbox/cls are replaced by the
        # labeled targets plus the positive-anchor count.
        cls_targets, box_targets, num_positives = self.anchor_labeler.label_anchors(
            ann['bbox'], ann['cls'])
        ann.pop('bbox')
        ann.pop('cls')
        ann['num_positives'] = num_positives
        ann.update(cls_targets)
        ann.update(box_targets)

        return img, ann

    def __len__(self):
        return len(self.img_ids)
def create_datasets_and_loaders(
        args,
        model_config,
        transform_train_fn=None,
        transform_eval_fn=None,
        collate_fn=None,
):
    """Setup datasets, transforms, loaders, evaluator.

    Args:
        args: Command line args / config for training
        model_config: Model specific configuration dict / struct
        transform_train_fn: Override default image + annotation transforms (see note in loaders.py)
        transform_eval_fn: Override default image + annotation transforms (see note in loaders.py)
        collate_fn: Override default fast collate function

    Returns:
        Train loader, validation loader, evaluator
    """
    input_config = resolve_input_config(args, model_config=model_config)
    dataset_train, dataset_eval = create_dataset(args.dataset, args.root)

    # Label anchors in the loader/collate_fn unless the model bench does it.
    labeler = None
    if not args.bench_labeler:
        labeler = AnchorLabeler(
            Anchors.from_config(model_config), model_config.num_classes, match_threshold=0.5)

    loader_train = create_loader(
        dataset_train,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        # color_jitter=args.color_jitter,
        # auto_augment=args.aa,
        interpolation=args.train_interpolation or input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
        transform_fn=transform_train_fn,
        collate_fn=collate_fn,
    )

    if args.val_skip > 1:
        dataset_eval = SkipSubset(dataset_eval, args.val_skip)
    loader_eval = create_loader(
        dataset_eval,
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
        transform_fn=transform_eval_fn,
        collate_fn=collate_fn,
    )

    evaluator = create_evaluator(
        args.dataset, loader_eval.dataset, distributed=args.distributed, pred_yxyx=False)

    return loader_train, loader_eval, evaluator
class EfficientDetTrain(pl.LightningModule):
    """PyTorch-Lightning module wrapping EfficientDet for train/validation.

    The duplicated loss computation previously inlined in both
    ``training_step`` and ``validation_step`` is factored into the private
    ``_compute_losses`` helper; both steps keep their original signatures
    and return values.
    """

    def __init__(self, model_name, num_classes, create_labeler=True):
        """Build the model, anchors, optional labeler, and loss.

        Args:
            model_name: Name of a predefined efficientdet configuration.
            num_classes: Number of detection classes; the head is reset to match.
            create_labeler: When True, build an AnchorLabeler; otherwise
                pre-computed anchor labels are expected in the targets.
        """
        super().__init__()
        self.save_hyperparameters()
        config = get_efficientdet_config(model_name)
        self.model = EfficientDet(config)
        self.model.reset_head(num_classes=num_classes)
        self.num_levels = self.model.config.num_levels
        self.num_classes = self.model.config.num_classes
        self.anchors = Anchors.from_config(self.model.config)
        self.anchor_labeler = None
        if create_labeler:
            # IoU >= 0.5 counts as a positive anchor match.
            self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(self.model.config)

    def forward(self, x):
        """Run the detection model; returns (class head out, box head out)."""
        class_out, box_out = self.model(x)
        return class_out, box_out

    def _compute_losses(self, batch):
        """Shared train/val logic: stack images, label anchors, compute loss.

        Returns:
            (total loss, class loss, box loss) as produced by the loss_fn.
        """
        x, targets, idx = batch
        x = torch.stack(x, dim=0)
        class_out, box_out = self.forward(x)
        target = {
            'bbox': [tar['bbox'].float() for tar in targets],
            'cls': [tar['cls'].float() for tar in targets],
        }
        if self.anchor_labeler is None:
            # Target should contain pre-computed anchor labels if labeler not
            # present in bench.
            # NOTE(review): `target` is built locally above with only
            # 'bbox'/'cls' keys, so this assert will always fire when the
            # labeler is disabled — confirm label_* keys are meant to come
            # from the dataloader targets instead.
            assert 'label_num_positives' in target
            cls_targets = [target[f'label_cls_{l}'] for l in range(self.num_levels)]
            box_targets = [target[f'label_bbox_{l}'] for l in range(self.num_levels)]
            num_positives = target['label_num_positives']
        else:
            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
                target['bbox'], target['cls'])
        return self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)

    def training_step(self, batch, batch_idx):
        """One optimization step; returns the total detection loss."""
        loss, class_loss, box_loss = self._compute_losses(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        """One validation step; returns the loss components as a dict."""
        loss, class_loss, box_loss = self._compute_losses(batch)
        return {'val_loss': loss, "class_loss": class_loss, 'box_loss': box_loss}

    def configure_optimizers(self):
        """SGD with momentum; lr/weight decay are fixed experiment defaults."""
        optimizer = torch.optim.SGD(
            self.parameters(), lr=1e-4, momentum=0.9, weight_decay=4e-5)
        return optimizer