Example #1
import multiprocessing
import os

import cv2
import numpy as np
import torch
from lvis import LVIS
from skimage import io
from tqdm import tqdm


class LvisSaver(object):
    def __init__(self, model, path):
        self.lvis = LVIS(
            '/scratch/users/zzweng/datasets/lvis/lvis_v0.5_val.json')
        self.model = model
        self.model.eval()
        self.path = path

    def save(self, k=50):
        # Split the 5000 validation images into k chunks and process the
        # chunks concurrently.
        rng = np.linspace(0, 5000, k + 1, dtype=int)
        args = list(zip(rng[:-1], rng[1:]))
        with multiprocessing.Pool(processes=6) as pool:
            pool.starmap(self._save_part, args)
        print('Done')

    def _save_part(self, start, end):
        print('Getting features from {} to {}'.format(start, end))
        img_ids = self.lvis.get_img_ids()[start:end]
        feature_y = []
        feature_x = []
        for img_id in tqdm(img_ids):
            img = self.lvis.load_imgs([img_id])[0]
            I = io.imread(img['coco_url'])
            if len(I.shape) == 2:  # skip grayscale images
                continue

            for ann_id in self.lvis.get_ann_ids(img_ids=[img_id]):
                ann = self.lvis.load_anns([ann_id])[0]
                b = np.array(ann['bbox']).astype(int)
                try:
                    # Mask out everything outside the instance, crop the
                    # bounding box, and resize to the model input size.
                    I_masked = I * np.expand_dims(self.lvis.ann_to_mask(ann),
                                                  2)
                    patch = I_masked[b[1]:b[1] + b[3],
                                     b[0]:b[0] + b[2], :] / 255.
                    patch = cv2.resize(patch, (224, 224))
                    patch_tensor = torch.tensor(patch).float()
                    # NHWC -> NCHW; the model returns features at index 1.
                    feat = self.model(
                        patch_tensor.view(1, *patch_tensor.shape).permute(
                            0, 3, 1, 2))[1].detach().numpy().flatten()
                    feature_x.append(feat)
                    feature_y.append(ann['category_id'])
                except Exception:
                    print('skipping annotation with bbox', b)

        feature_x_arr = np.stack(feature_x)
        feature_y_arr = np.array(feature_y)
        print(feature_x_arr.shape, feature_y_arr.shape)

        np.save(
            os.path.join(self.path, 'val_feats_{}_{}_x.npy'.format(start,
                                                                   end)),
            feature_x_arr)
        np.save(
            os.path.join(self.path, 'val_feats_{}_{}_y.npy'.format(start,
                                                                   end)),
            feature_y_arr)
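
A minimal usage sketch, assuming a feature extractor whose forward pass returns a tuple with the feature tensor at index 1 (which is what _save_part indexes); the wrapper class, weights, and output path below are placeholders, not part of the original code:

import torch
import torchvision.models as tvm

class FeatureTupleModel(torch.nn.Module):
    # Hypothetical wrapper: returns (features, features) so that
    # LvisSaver._save_part's `self.model(...)[1]` indexing works.
    def __init__(self):
        super().__init__()
        self.backbone = tvm.resnet18(weights=None)
        self.backbone.fc = torch.nn.Identity()

    def forward(self, x):
        feats = self.backbone(x)
        return feats, feats

saver = LvisSaver(FeatureTupleModel(), '/tmp/lvis_feats')  # path is a placeholder
saver.save(k=50)  # 50 chunks of 100 images each, 6 worker processes
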
Example #2
def load_lvis_json(json_file, image_root, dataset_name=None):
    from lvis import LVIS

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    if dataset_name is not None:
        meta = get_lvis_instances_meta(dataset_name)
        MetadataCatalog.get(dataset_name).set(**meta)

    img_ids = sorted(lvis_api.imgs.keys())
    imgs = lvis_api.load_imgs(img_ids)
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), \
        f"Annotation ids in '{json_file}' are not unique"

    imgs_anns = list(zip(imgs, anns))

    logger.info(f"Loaded {len(imgs_anns)} images in the LVIS format from {json_file}")

    dataset_dicts = []

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        file_name = img_dict["file_name"]
        if img_dict["file_name"].startswith("COCO"):
            file_name = file_name[-16:]
        record["file_name"] = os.path.join(image_root, file_name)
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            assert anno["image_id"] == image_id

            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            obj["category_id"] = anno["category_id"] - 1
            segm = anno["segmentation"]
            valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
            assert len(segm) == len(valid_segm), \
                "Annotation contains an invalid polygon with < 3 points"
            assert len(segm) > 0

            obj["segmentation"] = segm
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts
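
A usage sketch under the assumption that this loader is registered with detectron2's DatasetCatalog, as detectron2's own LVIS loaders are; the dataset name and paths below are placeholders:

from detectron2.data import DatasetCatalog

# Hypothetical registration; paths and the registered name are placeholders.
DatasetCatalog.register(
    "lvis_v0.5_val_custom",
    lambda: load_lvis_json("/path/to/lvis_v0.5_val.json",
                           "/path/to/coco_images",
                           dataset_name="lvis_v0.5_val"),
)
dicts = DatasetCatalog.get("lvis_v0.5_val_custom")
print(dicts[0]["file_name"], len(dicts[0]["annotations"]))
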
Example #3
class LVISDetection(VisionDataset):
    """`LVIS <https://www.lvisdataset.org/>`_ Detection Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes an input sample and its target
            as entry and returns a transformed version.

    """

    def __init__(
            self,
            root: str,
            annFile: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        super(LVISDetection, self).__init__(root, transforms, transform, target_transform)
        self.lvis = LVIS(annFile)
        self.ids = list(sorted(self.lvis.imgs.keys()))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target


    def __len__(self) -> int:
        return len(self.ids)

    def _load_image(self, id: int) -> Image.Image:
        # LVIS stores only coco_url; recover "<split>/<file>.jpg" from it
        # and resolve it against the parent directory of self.root.
        path = self.lvis.load_imgs([id])[0]["coco_url"]
        path = path.split('/')[-2] + "/" + path.split('/')[-1]
        root = '/'.join(self.root.split('/')[:-1])
        return Image.open(os.path.join(root, path)).convert("RGB")

    def _load_target(self, id) -> List[Any]:
        return self.lvis.load_anns(self.lvis.get_ann_ids([id]))
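
A usage sketch with a torchvision transform; the paths are placeholders. Because _load_image resolves images against the parent directory of root, root should point at a split directory such as .../val2017:

import torchvision.transforms as T

dataset = LVISDetection(
    root="/path/to/coco/val2017",          # placeholder
    annFile="/path/to/lvis_v1_val.json",   # placeholder
    transform=T.ToTensor(),
)
image, target = dataset[0]  # tensor image, list of LVIS annotation dicts
print(image.shape, len(target))
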
Example #4
    def _load_annotations(self, json_file, image_root):
        """
        Load a json file in LVIS's annotation format.
        Args:
            json_file (str): full path to the LVIS json annotation file.
            image_root (str): the directory where the images in this json file exists.
        Returns:
            list[dict]: a list of dicts in cvpods standard format. (See
            `Using Custom Datasets </tutorials/datasets.html>`_ )
        Notes:
            1. This function does not read the image files.
            The results do not have the "image" field.
        """
        from lvis import LVIS

        json_file = PathManager.get_local_path(json_file)

        timer = Timer()
        lvis_api = LVIS(json_file)
        if timer.seconds() > 1:
            logger.info("Loading {} takes {:.2f} seconds.".format(
                json_file, timer.seconds()))

        # sort indices for reproducible results
        img_ids = sorted(lvis_api.imgs.keys())
        # imgs is a list of dicts, each looks something like:
        # {'license': 4,
        #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
        #  'file_name': 'COCO_val2014_000000001268.jpg',
        #  'height': 427,
        #  'width': 640,
        #  'date_captured': '2013-11-17 05:57:24',
        #  'id': 1268}
        imgs = lvis_api.load_imgs(img_ids)
        # anns is a list[list[dict]], where each dict is an annotation
        # record for an object. The inner list enumerates the objects in an image
        # and the outer list enumerates over images. Example of anns[0]:
        # [{'segmentation': [[192.81,
        #     247.09,
        #     ...
        #     219.03,
        #     249.06]],
        #   'area': 1035.749,
        #   'image_id': 1268,
        #   'bbox': [192.81, 224.8, 74.73, 33.43],
        #   'category_id': 16,
        #   'id': 42986},
        #  ...]
        anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

        # Sanity check that each annotation has a unique id
        ann_ids = [
            ann["id"] for anns_per_image in anns for ann in anns_per_image
        ]
        assert len(set(ann_ids)) == len(
            ann_ids), "Annotation ids in '{}' are not unique".format(json_file)

        imgs_anns = list(zip(imgs, anns))

        logger.info("Loaded {} images in the LVIS format from {}".format(
            len(imgs_anns), json_file))

        dataset_dicts = []
        for (img_dict, anno_dict_list) in imgs_anns:
            record = {}
            file_name = img_dict["file_name"]
            if img_dict["file_name"].startswith("COCO"):
                # Convert from the COCO 2014 file naming convention of
                # COCO_[train/val/test]2014_000000000000.jpg to the 2017 naming convention of
                # 000000000000.jpg (LVIS v1 will fix this naming issue)
                file_name = file_name[-16:]
            record["file_name"] = os.path.join(image_root, file_name)
            record["height"] = img_dict["height"]
            record["width"] = img_dict["width"]
            record["not_exhaustive_category_ids"] = img_dict.get(
                "not_exhaustive_category_ids", [])
            record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
            image_id = record["image_id"] = img_dict["id"]

            objs = []
            for anno in anno_dict_list:
                # Check that the image_id in this annotation is the same as
                # the image_id we're looking at.
                # This fails only when the data parsing logic or the annotation file is buggy.
                assert anno["image_id"] == image_id
                obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
                obj["category_id"] = anno[
                    "category_id"] - 1  # Convert 1-indexed to 0-indexed
                segm = anno["segmentation"]  # list[list[float]]
                # filter out invalid polygons (< 3 points)
                valid_segm = [
                    poly for poly in segm
                    if len(poly) % 2 == 0 and len(poly) >= 6
                ]
                assert len(segm) == len(
                    valid_segm
                ), "Annotation contains an invalid polygon with < 3 points"
                assert len(segm) > 0
                obj["segmentation"] = segm
                objs.append(obj)
            record["annotations"] = objs
            dataset_dicts.append(record)

        return dataset_dicts
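
The boxes in each record are stored as XYWH_ABS; a short sketch, assuming detectron2's BoxMode.convert (cvpods ships an equivalent), of turning one record's boxes into XYXY for downstream use:

from detectron2.structures import BoxMode

def boxes_as_xyxy(record):
    # `record` is assumed to be one dict from the list returned above.
    return [
        BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
        for obj in record["annotations"]
    ]
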
Example #5
class LVISV1Dataset(LVISDataset):

    CLASSES = (
        'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol',
        'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna',
        'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
        'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
        'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
        'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
        'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
        'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
        'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
        'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
        'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
        'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
        'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
        'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
        'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
        'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
        'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
        'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
        'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
        'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
        'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
        'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
        'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
        'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
        'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)',
        'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box',
        'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
        'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase',
        'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts',
        'bubble_gum', 'bucket', 'horse_buggy', 'bull', 'bulldog', 'bulldozer',
        'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn',
        'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card',
        'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
        'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
        'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
        'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar',
        'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup',
        'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
        'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
        'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
        'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
        'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower',
        'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone',
        'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier',
        'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard',
        'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime',
        'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
        'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
        'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
        'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
        'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine',
        'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock',
        'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster',
        'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach',
        'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table',
        'coffeepot', 'coil', 'coin', 'colander', 'coleslaw',
        'coloring_material', 'combination_lock', 'pacifier', 'comic_book',
        'compass', 'computer_keyboard', 'condiment', 'cone', 'control',
        'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
        'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
        'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
        'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
        'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
        'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
        'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
        'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
        'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
        'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
        'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
        'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
        'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
        'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
        'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
        'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
        'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
        'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
        'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)',
        'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell',
        'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring',
        'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
        'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
        'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
        'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
        'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm',
        'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
        'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl',
        'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap',
        'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
        'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
        'folding_chair', 'food_processor', 'football_(American)',
        'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
        'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
        'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage',
        'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic',
        'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator',
        'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture',
        'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
        'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
        'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat',
        'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly',
        'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet',
        'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock',
        'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
        'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
        'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband',
        'headboard', 'headlight', 'headscarf', 'headset',
        'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
        'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
        'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
        'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
        'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
        'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
        'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
        'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
        'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
        'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
        'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
        'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
        'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
        'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
        'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
        'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce',
        'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
        'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
        'lizard', 'log', 'lollipop', 'speaker_(stero_equipment)', 'loveseat',
        'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
        'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger',
        'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato',
        'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox',
        'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine',
        'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone',
        'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror',
        'mitten', 'mixer_(kitchen_tool)', 'money',
        'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
        'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)',
        'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
        'music_stool', 'musical_instrument', 'nailfile', 'napkin',
        'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper',
        'newsstand', 'nightshirt', 'nosebag_(for_animals)',
        'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
        'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
        'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich',
        'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad',
        'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas',
        'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake',
        'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
        'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol',
        'parchment', 'parka', 'parking_meter', 'parrot',
        'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
        'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
        'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
        'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
        'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
        'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
        'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
        'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
        'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
        'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
        'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
        'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
        'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
        'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
        'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato',
        'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel',
        'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
        'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
        'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
        'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
        'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
        'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
        'recliner', 'record_player', 'reflector', 'remote_control',
        'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
        'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
        'rolling_pin', 'root_beer', 'router_(computer_equipment)',
        'rubber_band', 'runner_(carpet)', 'plastic_bag',
        'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
        'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
        'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
        'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
        'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
        'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
        'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
        'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
        'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
        'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
        'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
        'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
        'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
        'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
        'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
        'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
        'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
        'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
        'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
        'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
        'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
        'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
        'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
        'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer',
        'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer',
        'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign',
        'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl',
        'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses',
        'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband',
        'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword',
        'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
        'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
        'tambourine', 'army_tank', 'tank_(storage_vessel)',
        'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
        'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
        'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
        'telephone_pole', 'telephoto_lens', 'television_camera',
        'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
        'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
        'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil',
        'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
        'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
        'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
        'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light',
        'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
        'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle',
        'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat',
        'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
        'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
        'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
        'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
        'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
        'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
        'washbasin', 'automatic_washer', 'watch', 'water_bottle',
        'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
        'water_gun', 'water_scooter', 'water_ski', 'water_tower',
        'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
        'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
        'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
        'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
        'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
        'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
        'yoke_(animal_equipment)', 'zebra', 'zucchini')

    def load_annotations(self, ann_file):
        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVIS
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        self.coco = LVIS(ann_file)
        assert not self.custom_classes, 'LVIS custom classes is not supported'
        self.cat_ids = self.coco.get_cat_ids()
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            # coco_url is used in LVISv1 instead of file_name
            # e.g. http://images.cocodataset.org/train2017/000000391895.jpg
            # train/val split is specified in the url
            info['filename'] = info['coco_url'].replace(
                'http://images.cocodataset.org/', '')
            data_infos.append(info)
        return data_infos
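
The coco_url-to-filename mapping above is pure string manipulation; a tiny self-contained check of what it produces:

# Sketch of the filename derivation used in load_annotations above.
url = 'http://images.cocodataset.org/train2017/000000391895.jpg'
filename = url.replace('http://images.cocodataset.org/', '')
assert filename == 'train2017/000000391895.jpg'
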
Example #6
class LVISV05Dataset(CocoDataset):

    CLASSES = (
        'acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
        'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
        'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
        'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
        'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
        'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
        'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
        'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
        'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
        'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
        'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
        'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
        'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead',
        'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
        'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can',
        'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench',
        'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder', 'binoculars',
        'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse',
        'birthday_cake', 'birthday_card', 'biscuit_(bread)', 'pirate_flag',
        'black_sheep', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp',
        'blinker', 'blueberry', 'boar', 'gameboard', 'boat', 'bobbin',
        'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet',
        'book', 'book_bag', 'bookcase', 'booklet', 'bookmark',
        'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet',
        'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl',
        'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
        'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
        'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
        'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
        'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
        'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
        'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
        'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
        'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
        'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
        'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
        'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
        'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
        'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
        'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
        'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
        'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
        'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
        'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
        'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
        'champagne', 'chandelier', 'chap', 'checkbook', 'checkerboard',
        'cherry', 'chessboard', 'chest_of_drawers_(furniture)',
        'chicken_(animal)', 'chicken_wire', 'chickpea', 'Chihuahua',
        'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
        'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
        'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
        'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
        'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
        'clementine', 'clip', 'clipboard', 'clock', 'clock_tower',
        'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
        'coat_hanger', 'coatrack', 'cock', 'coconut', 'coffee_filter',
        'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin',
        'colander', 'coleslaw', 'coloring_material', 'combination_lock',
        'pacifier', 'comic_book', 'computer_keyboard', 'concrete_mixer',
        'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cookie',
        'cookie_jar', 'cooking_utensil', 'cooler_(for_food)',
        'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn',
        'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
        'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
        'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
        'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
        'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
        'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
        'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
        'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
        'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
        'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
        'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
        'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
        'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
        'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
        'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
        'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
        'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
        'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
        'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
        'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
        'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
        'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
        'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
        'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
        'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm',
        'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
        'fireplug', 'fish', 'fish_(food)', 'fishbowl', 'fishing_boat',
        'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flash',
        'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
        'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
        'food_processor', 'football_(American)', 'football_helmet',
        'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
        'freshener', 'frisbee', 'frog', 'fruit_juice', 'fruit_salad',
        'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage',
        'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic',
        'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
        'gift_wrap', 'ginger', 'giraffe', 'cincture',
        'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
        'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
        'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
        'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
        'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
        'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
        'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
        'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
        'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
        'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
        'headband', 'headboard', 'headlight', 'headscarf', 'headset',
        'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
        'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
        'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
        'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
        'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
        'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
        'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
        'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
        'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
        'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
        'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
        'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
        'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
        'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
        'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
        'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
        'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
        'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
        'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
        'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
        'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine',
        'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
        'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
        'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
        'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
        'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
        'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
        'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
        'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
        'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
        'mound_(baseball)', 'mouse_(animal_rodent)',
        'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
        'music_stool', 'musical_instrument', 'nailfile', 'nameplate', 'napkin',
        'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newsstand',
        'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)',
        'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
        'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
        'orange_(fruit)', 'orange_juice', 'oregano', 'ostrich', 'ottoman',
        'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
        'padlock', 'paintbox', 'paintbrush', 'painting', 'pajamas', 'palette',
        'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
        'papaya', 'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
        'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
        'parchment', 'parka', 'parking_meter', 'parrot',
        'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
        'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
        'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
        'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
        'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
        'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
        'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
        'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
        'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
        'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
        'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
        'plate', 'platter', 'playing_card', 'playpen', 'pliers',
        'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
        'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
        'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
        'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato',
        'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'printer',
        'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding',
        'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet',
        'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car',
        'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft',
        'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
        'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
        'recliner', 'record_player', 'red_cabbage', 'reflector',
        'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
        'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
        'Rollerblade', 'rolling_pin', 'root_beer',
        'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
        'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
        'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
        'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
        'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
        'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
        'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
        'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
        'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
        'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
        'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
        'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
        'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart',
        'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head',
        'shower_curtain', 'shredder_(for_paper)', 'sieve', 'signboard', 'silo',
        'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka',
        'ski_pole', 'skirt', 'sled', 'sleeping_bag', 'sling_(bandage)',
        'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
        'snowmobile', 'soap', 'soccer_ball', 'sock', 'soda_fountain',
        'carbonated_water', 'sofa', 'softball', 'solar_array', 'sombrero',
        'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk',
        'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear',
        'spectacles', 'spice_rack', 'spider', 'sponge', 'spoon', 'sportswear',
        'spotlight', 'squirrel', 'stapler_(stapling_machine)', 'starfish',
        'statue_(sculpture)', 'steak_(food)', 'steak_knife',
        'steamer_(kitchen_appliance)', 'steering_wheel', 'stencil',
        'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer',
        'stirrup', 'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light',
        'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
        'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
        'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
        'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
        'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
        'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
        'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
        'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
        'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
        'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
        'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
        'telephone_pole', 'telephoto_lens', 'television_camera',
        'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
        'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
        'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil',
        'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
        'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
        'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
        'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light',
        'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
        'tray', 'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
        'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
        'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
        'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
        'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
        'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
        'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
        'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
        'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
        'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
        'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski',
        'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam',
        'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair',
        'whipped_cream', 'whiskey', 'whistle', 'wick', 'wig', 'wind_chime',
        'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock',
        'wine_bottle', 'wine_bucket', 'wineglass', 'wing_chair',
        'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath',
        'wrench', 'wristband', 'wristlet', 'yacht', 'yak', 'yogurt',
        'yoke_(animal_equipment)', 'zebra', 'zucchini')

    def load_annotations(self, ann_file):
        """Load annotation from lvis style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from LVIS api.
        """

        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVIS
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        self.coco = LVIS(ann_file)
        assert not self.custom_classes, 'LVIS custom classes is not supported'
        self.cat_ids = self.coco.get_cat_ids()
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            if info['file_name'].startswith('COCO'):
                # Convert from the COCO 2014 file naming convention of
                # COCO_[train/val/test]2014_000000000000.jpg to the 2017
                # naming convention of 000000000000.jpg
                # (LVIS v1 will fix this naming issue)
                info['filename'] = info['file_name'][-16:]
            else:
                info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=np.arange(0.5, 0.96, 0.05)):
        """Evaluation in LVIS protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of output json files,
                including the file path. If not specified, a temp file will
                be created. Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal numbers used for
                evaluating recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float]): IoU thresholds used for evaluating
                recalls. If set to a list, the average recall over all IoUs
                is also computed. Default: np.arange(0.5, 0.96, 0.05).

        Returns:
            dict[str, float]: LVIS style metrics.
        """

        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVISResults, LVISEval
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError('metric {} is not supported'.format(metric))

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)

        eval_results = {}
        # get original api
        lvis_gt = self.coco
        for metric in metrics:
            msg = 'Evaluating {}...'.format(metric)
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(results,
                                           proposal_nums,
                                           iou_thrs,
                                           logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results['AR@{}'.format(num)] = ar[i]
                    log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i]))
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            if metric not in result_files:
                raise KeyError('{} is not in results'.format(metric))
            try:
                lvis_dt = LVISResults(lvis_gt, result_files[metric])
            except IndexError:
                print_log('The testing results of the whole dataset are empty.',
                          logger=logger,
                          level=logging.ERROR)
                break

            iou_type = 'bbox' if metric == 'proposal' else metric
            lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type)
            lvis_eval.params.imgIds = self.img_ids
            if metric == 'proposal':
                lvis_eval.params.useCats = 0
                lvis_eval.params.maxDets = list(proposal_nums)
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                for k, v in lvis_eval.get_results().items():
                    if k.startswith('AR'):
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[k] = val
            else:
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                lvis_results = lvis_eval.get_results()
                if classwise:  # Compute per-category AP
                    # from https://github.com/facebookresearch/detectron2/
                    precisions = lvis_eval.eval['precision']
                    # precision: (iou, recall, cls, area range, max dets)
                    assert len(self.cat_ids) == precisions.shape[2]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        # max dets index -1: typically 100 per image
                        nm = self.coco.load_cats(catId)[0]
                        precision = precisions[:, :, idx, 0, -1]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                for k, v in lvis_results.items():
                    if k.startswith('AP'):
                        key = '{}_{}'.format(metric, k)
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[key] = val
                ap_summary = ' '.join([
                    '{}:{:.3f}'.format(k, float(v))
                    for k, v in lvis_results.items() if k.startswith('AP')
                ])
                eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary
            lvis_eval.print_results()
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results
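
A standalone sketch of the per-category AP table layout used above; the itertools pattern is the same, but the category names and AP values are made up for illustration:

import itertools

from terminaltables import AsciiTable

# Six (category, AP) pairs flatten into 12 cells, which fill two rows of
# a 6-column table ('category', 'AP' repeated three times).
results_per_category = [('cat', '0.512'), ('dog', '0.430'), ('bird', '0.287'),
                        ('fish', '0.104'), ('horse', '0.356'),
                        ('sheep', '0.298')]
num_columns = min(6, len(results_per_category) * 2)
results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest(
    *[results_flatten[i::num_columns] for i in range(num_columns)])
table_data = [headers] + [list(row) for row in results_2d]
print(AsciiTable(table_data).table)
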
Example #7
def load_filtered_lvis_json(json_file,
                            image_root,
                            metadata,
                            dataset_name=None):
    """
    Load a json file in LVIS's annotation format.
    Args:
        json_file (str): full path to the LVIS json annotation file.
        image_root (str): the directory where the images in this json file exist.
        metadata: metadata associated with dataset_name.
        dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
            If provided, this function will put "thing_classes" into the metadata
            associated with this dataset.
    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )
    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from lvis import LVIS

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(
            json_file, timer.seconds()))

    if dataset_name is not None and "train" in dataset_name:
        assert global_cfg.MODEL.ROI_HEADS.NUM_CLASSES == len(
            metadata["thing_classes"]
        ), "NUM_CLASSES should match number of categories: ALL=1230, NOVEL=454"

    # sort indices for reproducible results
    img_ids = sorted(list(lvis_api.imgs.keys()))
    imgs = lvis_api.load_imgs(img_ids)
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    # Sanity check that each annotation has a unique id
    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(
        ann_ids), "Annotation ids in '{}' are not unique".format(json_file)

    imgs_anns = list(zip(imgs, anns))

    logger.info("Loaded {} images in the LVIS format from {}".format(
        len(imgs_anns), json_file))

    dataset_dicts = []

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        file_name = img_dict["file_name"]
        if img_dict["file_name"].startswith("COCO"):
            file_name = file_name[-16:]
        record["file_name"] = os.path.join(image_root, file_name)
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["not_exhaustive_category_ids"] = img_dict.get(
            "not_exhaustive_category_ids", [])
        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            assert anno["image_id"] == image_id
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            if global_cfg.MODEL.ROI_HEADS.NUM_CLASSES == 454:
                # Novel classes only
                if anno["category_id"] - 1 not in LVIS_CATEGORIES_NOVEL_IDS:
                    continue
                obj["category_id"] = metadata["class_mapping"][
                    anno["category_id"] - 1]
            else:
                # Convert 1-indexed to 0-indexed
                obj["category_id"] = anno["category_id"] - 1
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts
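
A quick standalone check of the file_name[-16:] slice used above; the sample name is illustrative:

# A COCO-2014 style name ends in 12 digits plus '.jpg' (16 characters),
# which is exactly the bare 2017-style name.
name_2014 = 'COCO_val2014_000000001268.jpg'
assert name_2014[-16:] == '000000001268.jpg'
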
def load_lvis_json(json_file, image_root, dataset_name=None):
    """
    Load a json file in LVIS's annotation format.

    Args:
        json_file (str): full path to the LVIS json annotation file.
        image_root (str): the directory where the images in this json file exist.
        dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
            If provided, this function will put "thing_classes" into the metadata
            associated with this dataset.

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )

    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from lvis import LVIS

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(
            json_file, timer.seconds()))

    if dataset_name is not None:
        meta = get_lvis_instances_meta(dataset_name)
        MetadataCatalog.get(dataset_name).set(**meta)

    # sort indices for reproducible results
    img_ids = sorted(lvis_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = lvis_api.load_imgs(img_ids)
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images. Example of anns[0]:
    # [{'segmentation': [[192.81,
    #     247.09,
    #     ...
    #     219.03,
    #     249.06]],
    #   'area': 1035.749,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    # Sanity check that each annotation has a unique id
    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(
        ann_ids), "Annotation ids in '{}' are not unique".format(json_file)

    imgs_anns = list(zip(imgs, anns))

    logger.info("Loaded {} images in the LVIS format from {}".format(
        len(imgs_anns), json_file))

    def get_file_name(img_root, img_dict):
        # Determine the path including the split folder ("train2017", "val2017", "test2017") from
        # the coco_url field. Example:
        #   'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
        split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
        return os.path.join(img_root + split_folder, file_name)

    dataset_dicts = []

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        record["file_name"] = get_file_name(image_root, img_dict)
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["not_exhaustive_category_ids"] = img_dict.get(
            "not_exhaustive_category_ids", [])
        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            # This fails only when the data parsing logic or the annotation file is buggy.
            assert anno["image_id"] == image_id
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            obj["category_id"] = anno[
                "category_id"] - 1  # Convert 1-indexed to 0-indexed
            segm = anno["segmentation"]  # list[list[float]]
            # filter out invalid polygons (< 3 points)
            valid_segm = [
                poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6
            ]
            assert len(segm) == len(
                valid_segm
            ), "Annotation contains an invalid polygon with < 3 points"
            assert len(segm) > 0
            obj["segmentation"] = segm
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts
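
A standalone illustration of the polygon validity rule enforced above (a valid polygon needs an even coordinate count and at least 3 points); the coordinates are made up:

triangle = [0.0, 0.0, 10.0, 0.0, 5.0, 8.0]  # 3 points -> valid
degenerate = [0.0, 0.0, 10.0, 0.0]          # 2 points -> invalid
for poly in (triangle, degenerate):
    is_valid = len(poly) % 2 == 0 and len(poly) >= 6
    print(len(poly) // 2, 'points, valid:', is_valid)
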
Example #9
def extended_lvis_load(json_file, image_root, dataset_name=None):
    """
    Load a json file in LVIS's annotation format.

    Args:
        json_file (str): full path to the LVIS json annotation file.
        image_root (str): the directory where the images in this json file exist.
        dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
            If provided, this function will put "thing_classes" into the metadata
            associated with this dataset.

    Returns:
        list[dict]: a list of dicts in "Detectron2 Dataset" format. (See DATASETS.md)

    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from lvis import LVIS

    json_file = _cache_json_file(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(
            json_file, timer.seconds()))

    # sort indices for reproducible results
    img_ids = sorted(list(lvis_api.imgs.keys()))
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = lvis_api.load_imgs(img_ids)
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images. Example of anns[0]:
    # [{'segmentation': [[192.81,
    #     247.09,
    #     ...
    #     219.03,
    #     249.06]],
    #   'area': 1035.749,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    # Sanity check that each annotation has a unique id
    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(
        ann_ids), "Annotation ids in '{}' are not unique".format(json_file)

    imgs_anns = list(zip(imgs, anns))

    logger.info("Loaded {} images in the LVIS format from {}".format(
        len(imgs_anns), json_file))

    dataset_dicts = []

    count_ignore_image_root_warning = 0
    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        if "://" not in img_dict["file_name"]:
            file_name = img_dict["file_name"]
            if img_dict["file_name"].startswith("COCO"):
                # Convert form the COCO 2014 file naming convention of
                # COCO_[train/val/test]2014_000000000000.jpg to the 2017 naming
                # convention of 000000000000.jpg (LVIS v1 will fix this naming issue)
                file_name = file_name[-16:]
            record["file_name"] = os.path.join(image_root, file_name)
        else:
            if image_root is not None:
                count_ignore_image_root_warning += 1
                if count_ignore_image_root_warning == 1:
                    logger.warning(
                        ("Found '://' in file_name: {}, ignore image_root: {}"
                         "(logged once per dataset).").format(
                             img_dict["file_name"], image_root))
            record["file_name"] = img_dict["file_name"]
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["not_exhaustive_category_ids"] = img_dict.get(
            "not_exhaustive_category_ids", [])
        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            # Fails only when the data parsing logic or the annotation file is buggy.
            assert anno["image_id"] == image_id
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            obj["category_id"] = (anno["category_id"] - 1
                                  )  # Convert 1-indexed to 0-indexed
            segm = anno["segmentation"]
            # filter out invalid polygons (< 3 points)
            valid_segm = [
                poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6
            ]
            assert len(segm) == len(
                valid_segm
            ), "Annotation contains an invalid polygon with < 3 points"
            assert len(segm) > 0
            obj["segmentation"] = segm
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    if dataset_name:
        meta = MetadataCatalog.get(dataset_name)
        meta.thing_classes = get_extended_lvis_instances_meta(
            lvis_api)["thing_classes"]

    return dataset_dicts
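
A condensed, hypothetical sketch of the file_name branching in extended_lvis_load; the helper name resolve is invented for illustration:

import os

def resolve(file_name, image_root):
    # Remote URLs bypass image_root; COCO-2014 names are shortened first.
    if '://' in file_name:
        return file_name
    if file_name.startswith('COCO'):
        file_name = file_name[-16:]
    return os.path.join(image_root, file_name)

print(resolve('http://images.cocodataset.org/val2017/000000001268.jpg', 'imgs'))
print(resolve('COCO_val2014_000000001268.jpg', 'imgs'))  # imgs/000000001268.jpg
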
Example #10
class LVISEvalCustom(LVISEval):
    def __init__(self, lvis_gt, lvis_dt, iou_type="segm"):
        """Constructor for LVISEval.
        Args:
            lvis_gt (LVIS class instance, or str containing path of annotation file)
            lvis_dt (LVISResult class instance, or str containing path of result file,
            or list of dict)
            iou_type (str): segm or bbox evaluation
        """
        self.logger = logging.getLogger(__name__)

        if iou_type not in ["bbox", "segm"]:
            raise ValueError("iou_type: {} is not supported.".format(iou_type))

        if isinstance(lvis_gt, LVIS):
            self.lvis_gt = lvis_gt
        elif isinstance(lvis_gt, str):
            self.lvis_gt = LVIS(lvis_gt)
        else:
            raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt))

        if isinstance(lvis_dt, LVISResults):
            self.lvis_dt = lvis_dt
        elif isinstance(lvis_dt, (str, list)):
            # set max_dets=-1 so that no detections are discarded
            self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt, max_dets=-1)
        else:
            raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt))

        # per-image per-category evaluation results
        self.eval_imgs = defaultdict(list)
        self.eval = {}  # accumulated evaluation results
        self._gts = defaultdict(list)  # gt for evaluation
        self._dts = defaultdict(list)  # dt for evaluation
        self.params = ParamsCustom(iou_type=iou_type)  # parameters
        self.results = OrderedDict()
        self.ious = {}  # ious between all gts and dts

        self.params.img_ids = sorted(self.lvis_gt.get_img_ids())
        self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids())

    def _prepare(self):
        """Prepare self._gts and self._dts for evaluation based on params."""
        cat_ids = self.params.cat_ids if self.params.cat_ids else None

        gts = self.lvis_gt.load_anns(
            self.lvis_gt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)
        )
        dts = self.lvis_dt.load_anns(
            self.lvis_dt.get_ann_ids(img_ids=self.params.img_ids,
                                     cat_ids=None if self.params.use_proposal else cat_ids)
        )
        # convert ground truth to mask if iou_type == 'segm'
        if self.params.iou_type == "segm":
            self._to_mask(gts, self.lvis_gt)
            self._to_mask(dts, self.lvis_dt)

        # set ignore flag
        for gt in gts:
            if "ignore" not in gt:
                gt["ignore"] = 0

        for gt in gts:
            self._gts[gt["image_id"], gt["category_id"]].append(gt)

        # For federated dataset evaluation we will filter out all dt for an
        # image which belong to categories not present in gt and not present in
        # the negative list for an image. In other words, the detector is not
        # penalized for categories for which we have no gt information about
        # presence or absence in an image.
        img_data = self.lvis_gt.load_imgs(ids=self.params.img_ids)
        # per image map of categories not present in image
        img_nl = {d["id"]: d["not_exhaustive_category_ids"] for d in img_data}
        # per image list of categories present in image
        img_pl = defaultdict(set)
        for ann in gts:
            img_pl[ann["image_id"]].add(ann["category_id"])
        # per image map of categories which have missing gt. For these
        # categories we don't penalize the detector for false positives.
        self.img_nel = {d["id"]: d["not_exhaustive_category_ids"] for d in img_data}

        for dt in dts:
            img_id, cat_id = dt["image_id"], dt["category_id"]
            if self.params.use_proposal:  # for proposal eval
                # copy the dt for each present category; mutating the same
                # dict in place would leave every entry with the last cat_id
                for cat_id in img_pl[img_id]:
                    dt_copy = dict(dt)
                    dt_copy['category_id'] = cat_id
                    self._dts[img_id, cat_id].append(dt_copy)
                continue
            elif cat_id not in img_nl[img_id] and cat_id not in img_pl[img_id]:
                continue
            self._dts[img_id, cat_id].append(dt)

        self.freq_groups = self._prepare_freq_group()

    def _summarize(
        self, summary_type, iou_thr=None, area_rng="all", freq_group_idx=None
    ):
        aidx = [
            idx
            for idx, _area_rng in enumerate(self.params.area_rng_lbl)
            if _area_rng == area_rng
        ]

        if summary_type == 'ap':
            s = self.eval["precision"]
            if iou_thr is not None:
                tidx = np.where(iou_thr == self.params.iou_thrs)[0]
                s = s[tidx]
            if freq_group_idx is not None:
                s = s[:, :, self.freq_groups[freq_group_idx], aidx]
            else:
                s = s[:, :, :, aidx]
        else:
            s = self.eval["recall"]
            if iou_thr is not None:
                tidx = np.where(iou_thr == self.params.iou_thrs)[0]
                s = s[tidx]
            if freq_group_idx is not None:  # add freq_group for recall
                s = s[:, self.freq_groups[freq_group_idx], aidx]
            else:
                s = s[:, :, aidx]

        if len(s[s > -1]) == 0:
            mean_s = -1
        else:
            mean_s = np.mean(s[s > -1])
        return mean_s

    def summarize(self):
        """Compute and display summary metrics for evaluation results."""
        if not self.eval:
            raise RuntimeError("Please run accumulate() first.")

        max_dets = self.params.max_dets

        self.results["AP"]   = self._summarize('ap')
        self.results["AP50"] = self._summarize('ap', iou_thr=0.50)
        self.results["AP75"] = self._summarize('ap', iou_thr=0.75)
        self.results["APs"]  = self._summarize('ap', area_rng="small")
        self.results["APm"]  = self._summarize('ap', area_rng="medium")
        self.results["APl"]  = self._summarize('ap', area_rng="large")
        self.results["APr"]  = self._summarize('ap', freq_group_idx=0)
        self.results["APc"]  = self._summarize('ap', freq_group_idx=1)
        self.results["APf"]  = self._summarize('ap', freq_group_idx=2)

        key = "AR@{}".format(max_dets)
        self.results[key] = self._summarize('ar')

        for area_rng in ["small", "medium", "large"]:
            key = "AR{}@{}".format(area_rng[0], max_dets)
            self.results[key] = self._summarize('ar', area_rng=area_rng)
        # add freq_group for recall
        for idx, freq_group in enumerate(self.params.img_count_lbl):
            key = "AR{}@{}".format(freq_group[0], max_dets)
            self.results[key] = self._summarize('ar', freq_group_idx=idx)
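
A hypothetical usage sketch for LVISEvalCustom; the annotation and result paths are placeholders, and ParamsCustom is assumed to be available alongside the class:

# Placeholder paths; the flow mirrors the plain LVISEval API.
lvis_eval = LVISEvalCustom('lvis_v0.5_val.json', 'results.segm.json',
                           iou_type='segm')
lvis_eval.evaluate()
lvis_eval.accumulate()
lvis_eval.summarize()
for key, value in lvis_eval.results.items():
    print(key, value)
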
Example #11
class LvisDataset(DownloadableDataset):
    """ LVIS PyTorch Object Detection Dataset """
    def __init__(
        self,
        root: Union[str, Path] = None,
        *,
        train=True,
        transform=None,
        loader=default_loader,
        download=True,
        lvis_api=None,
        img_ids: List[int] = None,
    ):
        """
        Creates an instance of the LVIS dataset.

        :param root: The directory where the dataset can be found or downloaded.
            Defaults to None, which means that the default location for
            "lvis" will be used.
        :param train: If True, the training set will be returned. If False,
            the test set will be returned.
        :param transform: The transformation to apply to (img, annotations)
            values.
        :param loader: The image loader to use.
        :param download: If True, the dataset will be downloaded if needed.
        :param lvis_api: An instance of the LVIS class (from the lvis-api) to
            use. Defaults to None, which means that annotations will be loaded
            from the annotation json found in the root directory.
        :param img_ids: A list representing a subset of images to use. Defaults
            to None, which means that the dataset will contain all images
            in the LVIS dataset.
        """

        if root is None:
            root = default_dataset_location("lvis")

        self.train = train  # training set or test set
        self.transform = transform
        self.loader = loader
        self.bbox_crop = True
        self.img_ids = img_ids

        self.targets = None
        self.lvis_api = lvis_api

        super(LvisDataset, self).__init__(root,
                                          download=download,
                                          verbose=True)

        self._load_dataset()

    def _download_dataset(self) -> None:
        data2download = lvis_archives

        for name, url, checksum in data2download:
            if self.verbose:
                print("Downloading " + name + "...")

            result_file = self._download_file(url, name, checksum)
            if self.verbose:
                print("Download completed. Extracting...")

            self._extract_archive(result_file)
            if self.verbose:
                print("Extraction completed!")

    def _load_metadata(self) -> bool:
        must_load_api = self.lvis_api is None
        must_load_img_ids = self.img_ids is None
        try:
            # Load metadata
            if must_load_api:
                if self.train:
                    ann_json_path = str(self.root / "lvis_v1_train.json")
                else:
                    ann_json_path = str(self.root / "lvis_v1_val.json")

                self.lvis_api = LVIS(ann_json_path)

            if must_load_img_ids:
                self.img_ids = list(sorted(self.lvis_api.get_img_ids()))

            self.targets = LVISDetectionTargets(self.lvis_api, self.img_ids)

            # Try loading an image
            if len(self.img_ids) > 0:
                img_id = self.img_ids[0]
                img_dict: LVISImgEntry = \
                    self.lvis_api.load_imgs(ids=[img_id])[0]
                assert self._load_img(img_dict) is not None
        except BaseException:
            if must_load_api:
                self.lvis_api = None
            if must_load_img_ids:
                self.img_ids = None

            self.targets = None
            raise

        return True

    def _download_error_message(self) -> str:
        return (
            "[LVIS] Error downloading the dataset. Consider "
            "downloading it manually at: https://www.lvisdataset.org/dataset"
            " and placing it in: " + str(self.root))

    def __getitem__(self, index):
        """
        Loads an instance given its index.

        :param index: The index of the instance to retrieve.

        :return: a (sample, target) tuple where the target is a
            torchvision-style annotation for object detection
            https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
        """
        img_id = self.img_ids[index]
        img_dict: LVISImgEntry = self.lvis_api.load_imgs(ids=[img_id])[0]
        annotation_dicts = self.targets[index]

        # Transform from LVIS dictionary to torchvision-style target
        num_objs = len(annotation_dicts)

        boxes = []
        labels = []
        for i in range(num_objs):
            xmin = annotation_dicts[i]['bbox'][0]
            ymin = annotation_dicts[i]['bbox'][1]
            xmax = xmin + annotation_dicts[i]['bbox'][2]
            ymax = ymin + annotation_dicts[i]['bbox'][3]
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(annotation_dicts[i]['category_id'])

        if len(boxes) > 0:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
        else:
            boxes = torch.empty((0, 4), dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        image_id = torch.tensor([img_id])
        areas = []
        for i in range(num_objs):
            areas.append(annotation_dicts[i]['area'])
        areas = torch.as_tensor(areas, dtype=torch.float32)
        iscrowd = torch.zeros((num_objs, ), dtype=torch.int64)

        target = dict()
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = areas
        target["iscrowd"] = iscrowd

        img = self._load_img(img_dict)

        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target

    def __len__(self):
        return len(self.img_ids)

    def _load_img(self, img_dict: "LVISImgEntry"):
        coco_url = img_dict['coco_url']
        splitted_url = coco_url.split('/')
        img_path = splitted_url[-2] + '/' + splitted_url[-1]
        final_path = self.root / img_path  # <root>/train2017/<img_id>.jpg
        return self.loader(str(final_path))
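
A standalone check of the XYWH-to-XYXY conversion performed in __getitem__ above; the bbox values are illustrative:

bbox_xywh = [192.81, 224.8, 74.73, 33.43]
xmin, ymin = bbox_xywh[0], bbox_xywh[1]
xmax, ymax = xmin + bbox_xywh[2], ymin + bbox_xywh[3]
print([xmin, ymin, xmax, ymax])  # approximately [192.81, 224.8, 267.54, 258.23]
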
Example #12
class MiniLvisDataSet(LvisDataSet):
    """Only contains the same 80 classes as COCO dataset"""
    # TODO: fix the lvis-coco classes
    LVIS_TO_COCO = {
        'computer_mouse': 'mouse',
        'cellphone': 'cell_phone',
        'hair_dryer': 'hair_drier',
        'sausage': 'hot_dog',  # not sure!!!
        'laptop_computer': 'laptop',
        'microwave_oven': 'microwave',
        'toaster_oven': 'oven',
        'flower_arrangement': 'potted_plant',  # not sure!!
        'remote_control': 'remote',
        'ski': 'skis',
        'baseball': 'sports_ball',  # too many types of balls
        'wineglass': 'wine_glass',
    }

    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        COCO_CLASSES = sorted(list(CocoDataset.CLASSES))
        self.synonyms_classes = [
            value['synonyms'] for value in self.lvis.cats.values()
        ]
        self.cat_ids = []
        self.CLASSES = []
        # copy, so that removals below don't mutate COCO_CLASSES itself
        self.not_in_classes = list(COCO_CLASSES)
        for id in self.lvis.get_cat_ids():
            for name in self.lvis.cats[id]['synonyms']:
                if name in self.LVIS_TO_COCO:
                    self.cat_ids.append(id)
                    self.CLASSES.append(name)
                    self.not_in_classes.remove(self.LVIS_TO_COCO[name])
                    break
                elif '(' not in name and name in COCO_CLASSES:
                    self.cat_ids.append(id)
                    self.CLASSES.append(name)
                    self.not_in_classes.remove(name)
                    break
                elif '_(' in name:
                    new_name = name.split('_(')[0]
                    if new_name in COCO_CLASSES:
                        self.cat_ids.append(id)
                        self.CLASSES.append(name)
                        self.not_in_classes.remove(new_name)
                        break
        data_dir = osp.dirname(ann_file)
        with open(osp.join(data_dir, 'synonyms_classes.json'), 'w') as f:
            f.write(json.dumps(self.synonyms_classes, indent=2))
        with open(osp.join(data_dir, 'not_in_classes.json'), 'w') as f:
            f.write(json.dumps(self.not_in_classes, indent=2))
        with open(osp.join(data_dir, 'coco_classes.json'), 'w') as f:
            f.write(json.dumps(COCO_CLASSES, indent=2))
        with open(osp.join(data_dir, 'lvis_coco_classes.json'), 'w') as f:
            f.write(json.dumps(self.CLASSES, indent=2))
        self.CLASSES = tuple(self.CLASSES)
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.CLASSES = CocoDataset.CLASSES
        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx):
        img_id = self.img_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(ann_info, self.with_mask)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.lvis.anns.values())
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds
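
A standalone illustration of the synonym-matching fallback above: a parenthesized qualifier is stripped from an LVIS name before comparing against the COCO classes (the sample name is made up):

name = 'remote_control_(device)'
print(name.split('_(')[0])  # -> remote_control
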
Example #13
class LvisDataset1(CustomDataset):
    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        self.full_cat_ids = self.lvis.get_cat_ids()
        self.full_cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.full_cat_ids)
        }

        self.CLASSES = tuple(
            [item['name'] for item in self.lvis.dataset['categories']])
        self.cat_ids = self.lvis.get_cat_ids()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }

        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            # info['filename'] = info['file_name'].split('_')[-1]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx):
        img_id = self.data_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(self.data_infos[idx], ann_info)

    def get_ann_info_withoutparse(self, idx):
        img_id = self.data_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return ann_info

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.lvis.anns.values())
        for i, img_info in enumerate(self.data_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        # Mask parsing is disabled in this variant; only boxes and labels
        # are collected (masks could be recovered via lvis.ann_to_mask).
        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        # seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(bboxes=gt_bboxes,
                   labels=gt_labels,
                   bboxes_ignore=gt_bboxes_ignore)
        # masks=gt_masks,
        # seg_map=seg_map)

        return ann
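
A standalone note on the bbox convention in the parser above: it uses the older inclusive-corner form x1 + w - 1, unlike the x1 + w form in the torchvision-style example earlier (values are illustrative):

x1, y1, w, h = 10, 20, 5, 8
print([x1, y1, x1 + w - 1, y1 + h - 1])  # [10, 20, 14, 27]
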
Example #14
class OpenimageDataset(CocoDataset):

    CLASSES = (
        'Infant bed', 'Rose', 'Flag', 'Flashlight', 'Sea turtle', 'Camera',
        'Animal', 'Glove', 'Crocodile', 'Cattle', 'House', 'Guacamole',
        'Penguin', 'Vehicle registration plate', 'Bench', 'Ladybug',
        'Human nose', 'Watermelon', 'Flute', 'Butterfly', 'Washing machine',
        'Raccoon', 'Segway', 'Taco', 'Jellyfish', 'Cake', 'Pen', 'Cannon',
        'Bread', 'Tree', 'Shellfish', 'Bed', 'Hamster', 'Hat', 'Toaster',
        'Sombrero', 'Tiara', 'Bowl', 'Dragonfly', 'Moths and butterflies',
        'Antelope', 'Vegetable', 'Torch', 'Building',
        'Power plugs and sockets', 'Blender', 'Billiard table',
        'Cutting board', 'Bronze sculpture', 'Turtle', 'Broccoli', 'Tiger',
        'Mirror', 'Bear', 'Zucchini', 'Dress', 'Volleyball', 'Guitar',
        'Reptile', 'Golf cart', 'Tart', 'Fedora', 'Carnivore', 'Car',
        'Lighthouse', 'Coffeemaker', 'Food processor', 'Truck', 'Bookcase',
        'Surfboard', 'Footwear', 'Bench', 'Necklace', 'Flower', 'Radish',
        'Marine mammal', 'Frying pan', 'Tap', 'Peach', 'Knife', 'Handbag',
        'Laptop', 'Tent', 'Ambulance', 'Christmas tree', 'Eagle', 'Limousine',
        'Kitchen & dining room table', 'Polar bear', 'Tower', 'Football',
        'Willow', 'Human head', 'Stop sign', 'Banana', 'Mixer', 'Binoculars',
        'Dessert', 'Bee', 'Chair', 'Wood-burning stove', 'Flowerpot', 'Beaker',
        'Oyster', 'Woodpecker', 'Harp', 'Bathtub', 'Wall clock',
        'Sports uniform', 'Rhinoceros', 'Beehive', 'Cupboard',
        'Chicken', 'Man', 'Blue jay', 'Cucumber', 'Balloon', 'Kite',
        'Fireplace', 'Lantern', 'Missile', 'Book', 'Spoon', 'Grapefruit',
        'Squirrel', 'Orange', 'Coat', 'Punching bag', 'Zebra', 'Billboard',
        'Bicycle', 'Door handle', 'Mechanical fan', 'Ring binder', 'Table',
        'Parrot', 'Sock', 'Vase', 'Weapon', 'Shotgun', 'Glasses', 'Seahorse',
        'Belt', 'Watercraft', 'Window', 'Giraffe', 'Lion', 'Tire', 'Vehicle',
        'Canoe', 'Tie', 'Shelf', 'Picture frame', 'Printer', 'Human leg',
        'Boat', 'Slow cooker', 'Croissant', 'Candle', 'Pancake', 'Pillow',
        'Coin', 'Stretcher', 'Sandal', 'Woman', 'Stairs', 'Harpsichord',
        'Stool', 'Bus', 'Suitcase', 'Human mouth', 'Juice', 'Skull', 'Door',
        'Violin', 'Chopsticks', 'Digital clock', 'Sunflower', 'Leopard',
        'Bell pepper', 'Harbor seal', 'Snake', 'Sewing machine', 'Goose',
        'Helicopter', 'Seat belt', 'Coffee cup', 'Microwave oven', 'Hot dog',
        'Countertop', 'Serving tray', 'Dog bed', 'Beer', 'Sunglasses',
        'Golf ball', 'Waffle', 'Palm tree', 'Trumpet', 'Ruler', 'Helmet',
        'Ladder', 'Office building', 'Tablet computer', 'Toilet paper',
        'Pomegranate', 'Skirt', 'Gas stove', 'Cookie', 'Cart', 'Raven', 'Egg',
        'Burrito', 'Goat', 'Kitchen knife', 'Skateboard',
        'Salt and pepper shakers', 'Lynx', 'Boot', 'Platter', 'Ski',
        'Swimwear', 'Swimming pool', 'Drinking straw', 'Wrench', 'Drum', 'Ant',
        'Human ear', 'Headphones', 'Fountain', 'Bird', 'Jeans', 'Television',
        'Crab', 'Microphone', 'Home appliance', 'Snowplow', 'Beetle',
        'Artichoke', 'Jet ski', 'Stationary bicycle', 'Human hair',
        'Brown bear', 'Starfish', 'Fork', 'Lobster', 'Corded phone', 'Drink',
        'Saucer', 'Carrot', 'Insect', 'Clock', 'Castle', 'Tennis racket',
        'Ceiling fan', 'Asparagus', 'Jaguar', 'Musical instrument', 'Train',
        'Cat', 'Rifle', 'Dumbbell', 'Mobile phone', 'Taxi', 'Shower',
        'Pitcher', 'Lemon', 'Invertebrate', 'Turkey', 'High heels', 'Bust',
        'Elephant', 'Scarf', 'Barrel', 'Trombone', 'Pumpkin', 'Box', 'Tomato',
        'Frog', 'Bidet', 'Human face', 'Houseplant', 'Van', 'Shark',
        'Ice cream', 'Swim cap', 'Falcon', 'Ostrich', 'Handgun', 'Whiteboard',
        'Lizard', 'Pasta', 'Snowmobile', 'Light bulb', 'Window blind',
        'Muffin', 'Pretzel', 'Computer monitor', 'Horn', 'Furniture',
        'Sandwich', 'Fox', 'Convenience store', 'Fish', 'Fruit', 'Earrings',
        'Curtain', 'Grape', 'Sofa bed', 'Horse', 'Luggage and bags', 'Desk',
        'Crutch', 'Bicycle helmet', 'Tick', 'Airplane', 'Canary', 'Spatula',
        'Watch', 'Lily', 'Kitchen appliance', 'Filing cabinet', 'Aircraft',
        'Cake stand', 'Candy', 'Sink', 'Mouse', 'Wine', 'Wheelchair',
        'Goldfish', 'Refrigerator', 'French fries', 'Drawer', 'Treadmill',
        'Picnic basket', 'Dice', 'Cabbage', 'Football helmet', 'Pig', 'Person',
        'Shorts', 'Gondola', 'Honeycomb', 'Doughnut', 'Chest of drawers',
        'Land vehicle', 'Bat', 'Monkey', 'Dagger', 'Tableware', 'Human foot',
        'Mug', 'Alarm clock', 'Pressure cooker', 'Human hand', 'Tortoise',
        'Baseball glove', 'Sword', 'Pear', 'Miniskirt', 'Traffic sign', 'Girl',
        'Roller skates', 'Dinosaur', 'Porch', 'Human beard',
        'Submarine sandwich', 'Screwdriver', 'Strawberry', 'Wine glass',
        'Seafood', 'Racket', 'Wheel', 'Sea lion', 'Toy', 'Tea', 'Tennis ball',
        'Waste container', 'Mule', 'Cricket ball', 'Pineapple', 'Coconut',
        'Doll', 'Coffee table', 'Snowman', 'Lavender', 'Shrimp', 'Maple',
        'Cowboy hat', 'Goggles', 'Rugby ball', 'Caterpillar', 'Poster',
        'Rocket', 'Organ', 'Saxophone', 'Traffic light', 'Cocktail',
        'Plastic bag', 'Squash', 'Mushroom', 'Hamburger', 'Light switch',
        'Parachute', 'Teddy bear', 'Winter melon', 'Deer', 'Musical keyboard',
        'Plumbing fixture', 'Scoreboard', 'Baseball bat', 'Envelope',
        'Adhesive tape', 'Briefcase', 'Paddle', 'Bow and arrow', 'Telephone',
        'Sheep', 'Jacket', 'Boy', 'Pizza', 'Otter', 'Office supplies', 'Couch',
        'Cello', 'Bull', 'Camel', 'Ball', 'Duck', 'Whale', 'Shirt', 'Tank',
        'Motorcycle', 'Accordion', 'Owl', 'Porcupine', 'Sun hat', 'Nail',
        'Scissors', 'Swan', 'Lamp', 'Crown', 'Piano', 'Sculpture', 'Cheetah',
        'Oboe', 'Tin can', 'Mango', 'Tripod', 'Oven', 'Mouse', 'Barge',
        'Coffee', 'Snowboard', 'Common fig', 'Salad', 'Marine invertebrates',
        'Umbrella', 'Kangaroo', 'Human arm', 'Measuring cup', 'Snail',
        'Loveseat', 'Suit', 'Teapot', 'Bottle', 'Alpaca', 'Kettle', 'Trousers',
        'Popcorn', 'Centipede', 'Spider', 'Sparrow', 'Plate', 'Bagel',
        'Personal care', 'Apple', 'Brassiere', 'Bathroom cabinet',
        'studio couch', 'Computer keyboard', 'Table tennis racket', 'Sushi',
        'Cabinetry', 'Street light', 'Towel', 'Nightstand', 'Rabbit',
        'Dolphin', 'Dog', 'Jug', 'Wok', 'Fire hydrant', 'Human eye',
        'Skyscraper', 'Backpack', 'Potato', 'Paper towel', 'Lifejacket',
        'Bicycle wheel', 'Toilet')

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api.
        """
        try:
            from lvis import LVIS
        except ImportError:
            raise ImportError('Please follow config/lvis/README.md to '
                              'install open-mmlab forked lvis first.')
        self.coco = LVIS(ann_file)
        self.cat_ids = self.coco.get_cat_ids()
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 groupwise=True,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=np.array([0.5])):
        """Evaluation in LVIS protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of output json files.
                If not specified, a temp file will be created. Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float]): IoU threshold used for evaluating
                recalls. If set to a list, the average recall of all IoUs will
                also be computed. Default: 0.5.

        Returns:
            dict[str, float]: LVIS style metrics.
        """

        try:
            from lvis import LVISResults
            from .openimage_eval import OpenimageEval
        except ImportError:
            raise ImportError('Please follow config/lvis/README.md to '
                              'install open-mmlab forked lvis first.')
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError('metric {} is not supported'.format(metric))
        if iou_thrs is None:
            iou_thrs = np.linspace(.5,
                                   0.95,
                                   int(np.round((0.95 - .5) / .05)) + 1,
                                   endpoint=True)

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)

        eval_results = {}
        # get original api
        lvis_gt = self.coco
        for metric in metrics:
            msg = 'Evaluating {}...'.format(metric)
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(results,
                                           proposal_nums,
                                           iou_thrs,
                                           logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results['AR@{}'.format(num)] = ar[i]
                    log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i]))
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            if metric not in result_files:
                raise KeyError('{} is not in results'.format(metric))
            try:
                lvis_dt = LVISResults(lvis_gt, result_files[metric])
            except IndexError:
                print_log('The testing results of the whole dataset are empty.',
                          logger=logger,
                          level=logging.ERROR)
                break

            iou_type = 'bbox' if metric == 'proposal' else metric
            lvis_eval = OpenimageEval(lvis_gt, lvis_dt, iou_type)
            lvis_eval.params.imgIds = self.img_ids
            lvis_eval.params.iouThrs = iou_thrs
            if metric == 'proposal':
                lvis_eval.params.useCats = 0
                lvis_eval.params.maxDets = list(proposal_nums)
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                for k, v in lvis_eval.get_results().items():
                    if k.startswith('AR'):
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[k] = val
            else:
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                lvis_results = lvis_eval.get_results()
                classwise = True  # force per-category AP regardless of the arg
                if classwise:  # Compute per-category AP
                    # from https://github.com/facebookresearch/detectron2/
                    precisions = lvis_eval.eval['precision']
                    # precision: (iou, recall, cls, area range)
                    assert len(self.cat_ids) == precisions.shape[2]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        nm = self.coco.load_cats([catId])[0]
                        precision = precisions[:, :, idx, 0]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                    with open(f"per-category-ap-{metric}.txt", 'w') as f:
                        f.write(table.table)

                for k, v in lvis_results.items():
                    if k.startswith('AP'):
                        key = '{}_{}'.format(metric, k)
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[key] = val
                ap_summary = ' '.join([
                    '{}:{:.3f}'.format(k, float(v))
                    for k, v in lvis_results.items() if k.startswith('AP')
                ])
                eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary
            lvis_eval.print_results()
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,\
                labels, masks, seg_map. "masks" are raw annotations and not \
                decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(bboxes=gt_bboxes,
                   labels=gt_labels,
                   bboxes_ignore=gt_bboxes_ignore,
                   masks=gt_masks_ann,
                   seg_map=seg_map)

        return ann
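
A standalone check of the visibility test in _parse_ann_info above: a box lying entirely outside the image yields zero intersection and is skipped (the numbers are made up):

img_w, img_h = 640, 427
x1, y1, w, h = 650.0, 10.0, 30.0, 30.0  # starts right of the image border
inter_w = max(0, min(x1 + w, img_w) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_h) - max(y1, 0))
print(inter_w * inter_h == 0)  # True -> this annotation would be skipped
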
Example #15
class LvisDataSet(CustomDataset):
    def __init__(self, samples_per_cls_file=None, **kwargs):
        self.samples_per_cls_file = samples_per_cls_file
        super(LvisDataSet, self).__init__(**kwargs)

    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        self.cat_ids = self.lvis.get_cat_ids()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.CLASSES = [_ for _ in self.cat_ids]
        self.cat_instance_count = [_ for _ in self.cat_ids]
        self.cat_image_count = [_ for _ in self.cat_ids]
        img_count_lbl = ["r", "c", "f"]
        self.freq_groups = [[] for _ in img_count_lbl]
        self.cat_group_idxs = [_ for _ in self.cat_ids]
        freq_group_count = {'f': 0, 'cf': 0, 'rcf': 0}
        self.cat_fake_idxs = {
            'f': [-1 for _ in self.cat_ids],
            'cf': [-1 for _ in self.cat_ids],
            'rcf': [-1 for _ in self.cat_ids]
        }
        self.freq_group_dict = {'rcf': (0, 1, 2), 'cf': (1, 2), 'f': (2, )}
        for value in self.lvis.cats.values():
            idx = value['id'] - 1
            self.CLASSES[idx] = value['name']
            self.cat_instance_count[idx] = value['instance_count']
            self.cat_image_count[idx] = value['image_count']
            group_idx = img_count_lbl.index(value["frequency"])
            self.freq_groups[group_idx].append(idx + 1)
            self.cat_group_idxs[idx] = group_idx
            if group_idx == 0:  # rare
                freq_group_count['rcf'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
            elif group_idx == 1:  # common
                freq_group_count['rcf'] += 1
                freq_group_count['cf'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
                self.cat_fake_idxs['cf'][idx] = freq_group_count['cf']
            elif group_idx == 2:  # freq
                freq_group_count['rcf'] += 1
                freq_group_count['cf'] += 1
                freq_group_count['f'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
                self.cat_fake_idxs['cf'][idx] = freq_group_count['cf']
                self.cat_fake_idxs['f'][idx] = freq_group_count['f']

        if self.samples_per_cls_file is not None:
            with open(self.samples_per_cls_file, 'w') as file:
                file.writelines(str(x) + '\n' for x in self.cat_instance_count)

        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx, freq_groups=('rcf', )):
        img_id = self.img_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(ann_info,
                                    self.with_mask,
                                    freq_groups=freq_groups)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.lvis.anns.values())
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, ann_info, with_mask=True, freq_groups=('rcf', )):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        assert isinstance(freq_groups, tuple)
        gt_valid_idxs = {name: [] for name in freq_groups}
        gt_count = 0
        # Two formats are provided.
        # 1. mask: a binary map of the same size of the image.
        # 2. polys: each mask consists of one or several polys, each poly is a
        # list of float.
        if with_mask:
            gt_masks = []
            gt_mask_polys = []
            gt_poly_lens = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue

            for name in freq_groups:
                if self.cat_group_idxs[ann['category_id'] -
                                       1] in self.freq_group_dict[name]:
                    gt_valid_idxs[name].append(gt_count)
            gt_count += 1

            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            gt_bboxes.append(bbox)
            gt_labels.append(self.cat2label[ann['category_id']])
            if with_mask:
                gt_masks.append(self.lvis.ann_to_mask(ann))
                mask_polys = [
                    p for p in ann['segmentation'] if len(p) >= 6
                ]  # valid polygons have >= 3 points (6 coordinates)
                poly_lens = [len(p) for p in mask_polys]
                gt_mask_polys.append(mask_polys)
                gt_poly_lens.extend(poly_lens)
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            gt_valid_idxs=gt_valid_idxs  # add gt_valid_idxs
        )

        if with_mask:
            ann['masks'] = gt_masks
            # poly format is not used in the current implementation
            ann['mask_polys'] = gt_mask_polys
            ann['poly_lens'] = gt_poly_lens
        return ann
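
    # Illustrative sketch of the dict returned above (an example of the
    # structure, not output from a real run; num_gt is the number of kept
    # ground-truth instances):
    #
    #     ann = {
    #         'bboxes': ...,         # (num_gt, 4) float32, [x1, y1, x2, y2]
    #         'labels': ...,         # (num_gt,) int64, 1-based labels
    #         'bboxes_ignore': ...,  # (num_ignore, 4) float32
    #         'gt_valid_idxs': ...,  # e.g. {'rcf': [...], 'cf': [...], 'f': [...]}
    #         'masks': ...,          # binary masks, only if with_mask
    #         'mask_polys': ...,     # polygon lists, only if with_mask
    #         'poly_lens': ...,      # polygon lengths, only if with_mask
    #     }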

    def prepare_train_img(self, idx):
        img_info = self.img_infos[idx]
        # load image; LVIS v0.5 annotations may keep the original
        # 'COCO_val2014_' prefix in file_name, so it is stripped here before
        # joining with img_prefix
        if 'COCO' in img_info['filename']:
            img = mmcv_custom.imread(
                osp.join(
                    self.img_prefix, img_info['filename']
                    [img_info['filename'].find('COCO_val2014_') +
                     len('COCO_val2014_'):]))
        else:
            img = mmcv_custom.imread(
                osp.join(self.img_prefix, img_info['filename']))
        # corruption
        if self.corruption is not None:
            img = corrupt(img,
                          severity=self.corruption_severity,
                          corruption_name=self.corruption)
        # load proposals if necessary
        if self.proposals is not None:
            proposals = self.proposals[idx][:self.num_max_proposals]
            # TODO: Handle empty proposals properly. Currently images with
            # no proposals are simply skipped, although in principle they
            # could still be used for training.
            if len(proposals) == 0:
                return None
            if not (proposals.shape[1] == 4 or proposals.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposals.shape))
            if proposals.shape[1] == 5:
                scores = proposals[:, 4, None]
                proposals = proposals[:, :4]
            else:
                scores = None

        ann = self.get_ann_info(idx, freq_groups=('rcf', 'cf', 'f'))
        gt_bboxes = ann['bboxes']
        gt_labels = ann['labels']
        gt_valid_idxs = ann['gt_valid_idxs']
        if self.with_crowd:
            gt_bboxes_ignore = ann['bboxes_ignore']

        # skip the image if there is no valid gt bbox
        if len(gt_bboxes) == 0 and self.skip_img_without_anno:
            warnings.warn('Skip the image "%s" that has no valid gt bbox' %
                          osp.join(self.img_prefix, img_info['filename']))
            return None

        # extra augmentation
        if self.extra_aug is not None:
            img, gt_bboxes, gt_labels = self.extra_aug(img, gt_bboxes,
                                                       gt_labels)

        # apply transforms
        flip = np.random.rand() < self.flip_ratio
        # randomly sample a scale
        img_scale = random_scale(self.img_scales, self.multiscale_mode)
        img, img_shape, pad_shape, scale_factor = self.img_transform(
            img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        img = img.copy()
        if self.with_seg:
            gt_seg = mmcv_custom.imread(
                osp.join(self.seg_prefix,
                         img_info['filename'].replace('jpg', 'png')),
                flag='unchanged')
            gt_seg = self.seg_transform(gt_seg.squeeze(), img_scale, flip)
            gt_seg = mmcv.imrescale(gt_seg,
                                    self.seg_scale_factor,
                                    interpolation='nearest')
            gt_seg = gt_seg[None, ...]
        if self.proposals is not None:
            proposals = self.bbox_transform(proposals, img_shape, scale_factor,
                                            flip)
            if scores is not None:
                proposals = np.hstack([proposals, scores])

        gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
                                        flip)

        if self.with_crowd:
            gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
                                                   scale_factor, flip)
        if self.with_mask:
            gt_masks = self.mask_transform(ann['masks'], pad_shape,
                                           scale_factor, flip)

        ori_shape = (img_info['height'], img_info['width'], 3)
        not_exhaustive_category_ids = img_info['not_exhaustive_category_ids']
        neg_category_ids = img_info['neg_category_ids']
        img_meta = dict(
            ori_shape=ori_shape,
            img_shape=img_shape,
            pad_shape=pad_shape,
            scale_factor=scale_factor,
            flip=flip,
            not_exhaustive_category_ids=not_exhaustive_category_ids,
            neg_category_ids=neg_category_ids,
            # cat_group_idxs=self.cat_group_idxs,
            cat_instance_count=self.cat_instance_count,
            freq_groups=self.freq_groups,
            cat_fake_idxs=self.cat_fake_idxs,
            freq_group_dict=self.freq_group_dict,
        )

        data = dict(
            img=DC(to_tensor(img), stack=True),
            img_meta=DC(img_meta, cpu_only=True),
            gt_bboxes=DC(to_tensor(gt_bboxes)),
            gt_valid_idxs=DC(gt_valid_idxs, cpu_only=True),
        )
        if self.proposals is not None:
            data['proposals'] = DC(to_tensor(proposals))
        if self.with_label:
            data['gt_labels'] = DC(to_tensor(gt_labels))
        if self.with_crowd:
            data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore))
        if self.with_mask:
            data['gt_masks'] = DC(gt_masks, cpu_only=True)
        if self.with_seg:
            data['gt_semantic_seg'] = DC(to_tensor(gt_seg), stack=True)
        return data
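
    # Illustration only: a minimal sketch of how an mmdetection-v1.x-style
    # __getitem__ typically consumes prepare_train_img, assuming the base
    # class provides a test_mode flag and a _rand_another helper (both are
    # assumptions, not shown in this excerpt). It retries with another
    # random index whenever prepare_train_img returns None, e.g. for the
    # images skipped above for lacking a valid gt bbox:
    #
    #     def __getitem__(self, idx):
    #         if self.test_mode:
    #             return self.prepare_test_img(idx)
    #         while True:
    #             data = self.prepare_train_img(idx)
    #             if data is None:
    #                 idx = self._rand_another(idx)
    #                 continue
    #             return data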

    def prepare_test_img(self, idx):
        """Prepare an image for testing (multi-scale and flipping)"""
        img_info = self.img_infos[idx]
        # load image (same 'COCO_val2014_' prefix handling as in
        # prepare_train_img)
        if 'COCO' in img_info['filename']:
            img = mmcv_custom.imread(
                osp.join(
                    self.img_prefix, img_info['filename']
                    [img_info['filename'].find('COCO_val2014_') +
                     len('COCO_val2014_'):]))
        else:
            img = mmcv_custom.imread(
                osp.join(self.img_prefix, img_info['filename']))
        # corruption
        if self.corruption is not None:
            img = corrupt(img,
                          severity=self.corruption_severity,
                          corruption_name=self.corruption)
        # load proposals if necessary
        if self.proposals is not None:
            proposal = self.proposals[idx][:self.num_max_proposals]
            if not (proposal.shape[1] == 4 or proposal.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposal.shape))
        else:
            proposal = None

        def prepare_single(img, scale, flip, proposal=None):
            _img, img_shape, pad_shape, scale_factor = self.img_transform(
                img, scale, flip, keep_ratio=self.resize_keep_ratio)
            _img = to_tensor(_img)
            _img_meta = dict(
                ori_shape=(img_info['height'], img_info['width'], 3),
                img_shape=img_shape,
                pad_shape=pad_shape,
                scale_factor=scale_factor,
                flip=flip)
            if proposal is not None:
                if proposal.shape[1] == 5:
                    score = proposal[:, 4, None]
                    proposal = proposal[:, :4]
                else:
                    score = None
                _proposal = self.bbox_transform(proposal, img_shape,
                                                scale_factor, flip)
                if score is not None:
                    _proposal = np.hstack([_proposal, score])
                _proposal = to_tensor(_proposal)
            else:
                _proposal = None
            return _img, _img_meta, _proposal

        imgs = []
        img_metas = []
        proposals = []
        for scale in self.img_scales:
            _img, _img_meta, _proposal = prepare_single(
                img, scale, False, proposal)
            imgs.append(_img)
            img_metas.append(DC(_img_meta, cpu_only=True))
            proposals.append(_proposal)
            if self.flip_ratio > 0:
                _img, _img_meta, _proposal = prepare_single(
                    img, scale, True, proposal)
                imgs.append(_img)
                img_metas.append(DC(_img_meta, cpu_only=True))
                proposals.append(_proposal)
        data = dict(img=imgs, img_meta=img_metas)
        if self.proposals is not None:
            data['proposals'] = proposals
        return data
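
    # Illustrative note on the test path above: with len(self.img_scales)
    # scales and flip_ratio > 0, prepare_test_img returns two augmented
    # copies per scale (unflipped and flipped). A minimal usage sketch,
    # assuming an already-constructed instance named `dataset` (a
    # hypothetical name, not part of the original example):
    #
    #     data = dataset.prepare_test_img(0)
    #     expected = len(dataset.img_scales) * (2 if dataset.flip_ratio > 0
    #                                           else 1)
    #     assert len(data['img']) == expected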
Example #16
0
class LvisDataset(CustomDataset):

    # LVIS v0.5 has 1230 categories; use the 1-based category ids as
    # placeholder class names.
    CLASSES = tuple(str(i) for i in range(1, 1231))

    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        self.cat_ids = self.lvis.get_cat_ids()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos
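
    # Illustrative note: cat2label maps raw LVIS category ids to contiguous
    # 1-based training labels (0 is conventionally reserved for background),
    # e.g. with hypothetical cat_ids [3, 7, 12]:
    #
    #     cat2label == {3: 1, 7: 2, 12: 3}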

    def get_ann_info(self, idx):
        img_id = self.img_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(self.img_infos[idx], ann_info)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(ann['image_id'] for ann in self.lvis.anns.values())
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            img_info (dict): Image info of an image.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, seg_map. "masks" are raw annotations and not
                decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []

        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann['segmentation'])

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(bboxes=gt_bboxes,
                   labels=gt_labels,
                   bboxes_ignore=gt_bboxes_ignore,
                   masks=gt_masks_ann,
                   seg_map=seg_map)

        return ann
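
    # Worked example of the bbox conversion above (illustrative values):
    # LVIS stores boxes as [x, y, w, h]; the parser converts them to
    # inclusive [x1, y1, x2, y2] corners:
    #
    #     x1, y1, w, h = 10.0, 20.0, 30.0, 40.0
    #     bbox = [x1, y1, x1 + w - 1, y1 + h - 1]   # [10.0, 20.0, 39.0, 59.0]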