Example #1
class LVISV1Dataset(LVISDataset):

    CLASSES = (
        'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol',
        'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna',
        'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
        'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
        'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
        'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
        'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
        'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
        'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
        'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
        'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
        'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
        'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
        'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
        'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
        'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
        'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
        'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
        'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
        'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
        'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
        'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
        'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
        'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
        'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)',
        'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box',
        'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
        'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase',
        'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts',
        'bubble_gum', 'bucket', 'horse_buggy', 'bull', 'bulldog', 'bulldozer',
        'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn',
        'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card',
        'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
        'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
        'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
        'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar',
        'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup',
        'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
        'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
        'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
        'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
        'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower',
        'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone',
        'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier',
        'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard',
        'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime',
        'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
        'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
        'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
        'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
        'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine',
        'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock',
        'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster',
        'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach',
        'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table',
        'coffeepot', 'coil', 'coin', 'colander', 'coleslaw',
        'coloring_material', 'combination_lock', 'pacifier', 'comic_book',
        'compass', 'computer_keyboard', 'condiment', 'cone', 'control',
        'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
        'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
        'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
        'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
        'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
        'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
        'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
        'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
        'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
        'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
        'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
        'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
        'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
        'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
        'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
        'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
        'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
        'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
        'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)',
        'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell',
        'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring',
        'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
        'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
        'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
        'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
        'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm',
        'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
        'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl',
        'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap',
        'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
        'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
        'folding_chair', 'food_processor', 'football_(American)',
        'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
        'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
        'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage',
        'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic',
        'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator',
        'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture',
        'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
        'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
        'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat',
        'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly',
        'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet',
        'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock',
        'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
        'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
        'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband',
        'headboard', 'headlight', 'headscarf', 'headset',
        'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
        'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
        'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
        'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
        'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
        'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
        'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
        'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
        'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
        'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
        'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
        'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
        'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
        'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
        'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
        'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce',
        'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
        'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
        'lizard', 'log', 'lollipop', 'speaker_(stero_equipment)', 'loveseat',
        'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
        'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger',
        'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato',
        'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox',
        'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine',
        'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone',
        'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror',
        'mitten', 'mixer_(kitchen_tool)', 'money',
        'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
        'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)',
        'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
        'music_stool', 'musical_instrument', 'nailfile', 'napkin',
        'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper',
        'newsstand', 'nightshirt', 'nosebag_(for_animals)',
        'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
        'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
        'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich',
        'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad',
        'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas',
        'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake',
        'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
        'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol',
        'parchment', 'parka', 'parking_meter', 'parrot',
        'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
        'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
        'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
        'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
        'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
        'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
        'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
        'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
        'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
        'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
        'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
        'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
        'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
        'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
        'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato',
        'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel',
        'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
        'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
        'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
        'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
        'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
        'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
        'recliner', 'record_player', 'reflector', 'remote_control',
        'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
        'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
        'rolling_pin', 'root_beer', 'router_(computer_equipment)',
        'rubber_band', 'runner_(carpet)', 'plastic_bag',
        'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
        'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
        'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
        'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
        'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
        'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
        'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
        'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
        'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
        'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
        'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
        'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
        'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
        'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
        'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
        'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
        'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
        'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
        'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
        'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
        'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
        'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
        'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
        'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer',
        'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer',
        'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign',
        'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl',
        'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses',
        'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband',
        'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword',
        'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
        'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
        'tambourine', 'army_tank', 'tank_(storage_vessel)',
        'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
        'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
        'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
        'telephone_pole', 'telephoto_lens', 'television_camera',
        'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
        'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
        'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil',
        'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
        'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
        'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
        'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light',
        'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
        'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle',
        'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat',
        'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
        'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
        'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
        'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
        'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
        'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
        'washbasin', 'automatic_washer', 'watch', 'water_bottle',
        'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
        'water_gun', 'water_scooter', 'water_ski', 'water_tower',
        'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
        'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
        'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
        'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
        'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
        'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
        'yoke_(animal_equipment)', 'zebra', 'zucchini')

    def load_annotations(self, ann_file):
        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVIS
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        self.coco = LVIS(ann_file)
        assert not self.custom_classes, 'LVIS custom classes is not supported'
        self.cat_ids = self.coco.get_cat_ids()
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            # coco_url is used in LVISv1 instead of file_name
            # e.g. http://images.cocodataset.org/train2017/000000391895.jpg
            # train/val split is specified in the url
            info['filename'] = info['coco_url'].replace(
                'http://images.cocodataset.org/', '')
            data_infos.append(info)
        return data_infos
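
For clarity, a small standalone sketch of the coco_url-to-filename mapping performed above; the URL is the one quoted in the comment, and the resulting relative path keeps the train2017/val2017 prefix:

coco_url = 'http://images.cocodataset.org/train2017/000000391895.jpg'
filename = coco_url.replace('http://images.cocodataset.org/', '')
print(filename)  # train2017/000000391895.jpg
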
Example #2
class LVISEvalCustom(LVISEval):
    def __init__(self, lvis_gt, lvis_dt, iou_type="segm"):
        """Constructor for LVISEval.
        Args:
            lvis_gt (LVIS class instance, or str containing path of annotation file)
            lvis_dt (LVISResult class instance, or str containing path of result file,
            or list of dict)
            iou_type (str): segm or bbox evaluation
        """
        self.logger = logging.getLogger(__name__)

        if iou_type not in ["bbox", "segm"]:
            raise ValueError("iou_type: {} is not supported.".format(iou_type))

        if isinstance(lvis_gt, LVIS):
            self.lvis_gt = lvis_gt
        elif isinstance(lvis_gt, str):
            self.lvis_gt = LVIS(lvis_gt)
        else:
            raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt))

        if isinstance(lvis_dt, LVISResults):
            self.lvis_dt = lvis_dt
        elif isinstance(lvis_dt, (str, list)):
            # set max_dets=-1 so no detections are dropped by a per-image cap
            self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt, max_dets=-1)
        else:
            raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt))

        # per-image per-category evaluation results
        self.eval_imgs = defaultdict(list)
        self.eval = {}  # accumulated evaluation results
        self._gts = defaultdict(list)  # gt for evaluation
        self._dts = defaultdict(list)  # dt for evaluation
        self.params = ParamsCustom(iou_type=iou_type)  # parameters
        self.results = OrderedDict()
        self.ious = {}  # ious between all gts and dts

        self.params.img_ids = sorted(self.lvis_gt.get_img_ids())
        self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids())

    def _prepare(self):
        """Prepare self._gts and self._dts for evaluation based on params."""
        cat_ids = self.params.cat_ids if self.params.cat_ids else None

        gts = self.lvis_gt.load_anns(
            self.lvis_gt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)
        )
        dts = self.lvis_dt.load_anns(
            self.lvis_dt.get_ann_ids(img_ids=self.params.img_ids,
                                     cat_ids=None if self.params.use_proposal else cat_ids)
        )
        # convert ground truth to mask if iou_type == 'segm'
        if self.params.iou_type == "segm":
            self._to_mask(gts, self.lvis_gt)
            self._to_mask(dts, self.lvis_dt)

        # set ignore flag
        for gt in gts:
            if "ignore" not in gt:
                gt["ignore"] = 0

        for gt in gts:
            self._gts[gt["image_id"], gt["category_id"]].append(gt)

        # For federated dataset evaluation we filter out all dt for an image
        # which belong to categories not present in the gt and not present in
        # the negative list for that image. In other words, the detector is
        # not penalized for categories for which we have no gt information
        # about their presence or absence in an image.
        img_data = self.lvis_gt.load_imgs(ids=self.params.img_ids)
        # per image map of categories not present in image
        img_nl = {d["id"]: d["not_exhaustive_category_ids"] for d in img_data}
        # per image list of categories present in image
        img_pl = defaultdict(set)
        for ann in gts:
            img_pl[ann["image_id"]].add(ann["category_id"])
        # per image map of categories which have missing gt. For these
        # categories we don't penalize the detector for false positives.
        self.img_nel = {d["id"]: d["not_exhaustive_category_ids"] for d in img_data}

        for dt in dts:
            img_id, cat_id = dt["image_id"], dt["category_id"]
            if self.params.use_proposal:  # for proposal eval
                for cat_id in img_pl[img_id]:
                    dt['category_id'] = cat_id
                    self._dts[img_id, cat_id].append(dt)
                continue
            elif cat_id not in img_nl[img_id] and cat_id not in img_pl[img_id]:
                continue
            self._dts[img_id, cat_id].append(dt)

        self.freq_groups = self._prepare_freq_group()

    def _summarize(
        self, summary_type, iou_thr=None, area_rng="all", freq_group_idx=None
    ):
        aidx = [
            idx
            for idx, _area_rng in enumerate(self.params.area_rng_lbl)
            if _area_rng == area_rng
        ]

        if summary_type == 'ap':
            s = self.eval["precision"]
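            # precision has shape (n_iou_thrs, n_recall_points, n_cats, n_area_rngs)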
            if iou_thr is not None:
                tidx = np.where(iou_thr == self.params.iou_thrs)[0]
                s = s[tidx]
            if freq_group_idx is not None:
                s = s[:, :, self.freq_groups[freq_group_idx], aidx]
            else:
                s = s[:, :, :, aidx]
        else:
            s = self.eval["recall"]
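            # recall has shape (n_iou_thrs, n_cats, n_area_rngs)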
            if iou_thr is not None:
                tidx = np.where(iou_thr == self.params.iou_thrs)[0]
                s = s[tidx]
            if freq_group_idx is not None:  # add freq_group for recall
                s = s[:, self.freq_groups[freq_group_idx], aidx]
            else:
                s = s[:, :, aidx]

        if len(s[s > -1]) == 0:
            mean_s = -1
        else:
            mean_s = np.mean(s[s > -1])
        return mean_s

    def summarize(self):
        """Compute and display summary metrics for evaluation results."""
        if not self.eval:
            raise RuntimeError("Please run accumulate() first.")

        max_dets = self.params.max_dets

        self.results["AP"]   = self._summarize('ap')
        self.results["AP50"] = self._summarize('ap', iou_thr=0.50)
        self.results["AP75"] = self._summarize('ap', iou_thr=0.75)
        self.results["APs"]  = self._summarize('ap', area_rng="small")
        self.results["APm"]  = self._summarize('ap', area_rng="medium")
        self.results["APl"]  = self._summarize('ap', area_rng="large")
        self.results["APr"]  = self._summarize('ap', freq_group_idx=0)
        self.results["APc"]  = self._summarize('ap', freq_group_idx=1)
        self.results["APf"]  = self._summarize('ap', freq_group_idx=2)

        key = "AR@{}".format(max_dets)
        self.results[key] = self._summarize('ar')

        for area_rng in ["small", "medium", "large"]:
            key = "AR{}@{}".format(area_rng[0], max_dets)
            self.results[key] = self._summarize('ar', area_rng=area_rng)
        # add freq_group for recall
        for idx, freq_group in enumerate(self.params.img_count_lbl):
            key = "AR{}@{}".format(freq_group[0], max_dets)
            self.results[key] = self._summarize('ar', freq_group_idx=idx)
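
A minimal usage sketch, assuming the evaluate() and accumulate() steps are inherited unchanged from the upstream lvis-api LVISEval and that the annotation/detection paths are placeholders:

# Hypothetical paths; ParamsCustom (not shown here) also supplies the
# use_proposal flag read in _prepare().
lvis_eval = LVISEvalCustom('lvis_val.json', 'detections.json', iou_type='bbox')
lvis_eval.evaluate()    # inherited: per-image / per-category matching
lvis_eval.accumulate()  # inherited: fills self.eval['precision'] and ['recall']
lvis_eval.summarize()   # overridden above: fills self.results (AP, APr/APc/APf, AR@k)
print(lvis_eval.results)
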
Example #3
class LVISV05Dataset(CocoDataset):

    CLASSES = (
        'acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
        'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
        'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
        'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
        'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
        'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
        'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
        'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
        'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
        'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
        'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
        'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
        'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead',
        'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
        'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can',
        'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench',
        'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder', 'binoculars',
        'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse',
        'birthday_cake', 'birthday_card', 'biscuit_(bread)', 'pirate_flag',
        'black_sheep', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp',
        'blinker', 'blueberry', 'boar', 'gameboard', 'boat', 'bobbin',
        'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet',
        'book', 'book_bag', 'bookcase', 'booklet', 'bookmark',
        'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet',
        'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl',
        'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
        'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
        'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
        'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
        'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
        'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
        'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
        'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
        'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
        'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
        'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
        'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
        'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
        'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
        'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
        'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
        'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
        'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
        'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
        'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
        'champagne', 'chandelier', 'chap', 'checkbook', 'checkerboard',
        'cherry', 'chessboard', 'chest_of_drawers_(furniture)',
        'chicken_(animal)', 'chicken_wire', 'chickpea', 'Chihuahua',
        'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
        'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
        'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
        'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
        'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
        'clementine', 'clip', 'clipboard', 'clock', 'clock_tower',
        'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
        'coat_hanger', 'coatrack', 'cock', 'coconut', 'coffee_filter',
        'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin',
        'colander', 'coleslaw', 'coloring_material', 'combination_lock',
        'pacifier', 'comic_book', 'computer_keyboard', 'concrete_mixer',
        'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cookie',
        'cookie_jar', 'cooking_utensil', 'cooler_(for_food)',
        'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn',
        'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
        'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
        'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
        'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
        'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
        'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
        'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
        'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
        'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
        'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
        'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
        'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
        'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
        'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
        'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
        'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
        'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
        'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
        'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
        'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
        'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
        'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
        'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
        'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
        'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm',
        'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
        'fireplug', 'fish', 'fish_(food)', 'fishbowl', 'fishing_boat',
        'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flash',
        'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
        'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
        'food_processor', 'football_(American)', 'football_helmet',
        'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
        'freshener', 'frisbee', 'frog', 'fruit_juice', 'fruit_salad',
        'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage',
        'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic',
        'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
        'gift_wrap', 'ginger', 'giraffe', 'cincture',
        'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
        'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
        'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
        'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
        'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
        'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
        'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
        'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
        'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
        'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
        'headband', 'headboard', 'headlight', 'headscarf', 'headset',
        'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
        'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
        'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
        'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
        'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
        'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
        'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
        'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
        'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
        'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
        'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
        'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
        'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
        'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
        'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
        'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
        'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
        'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
        'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
        'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
        'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine',
        'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
        'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
        'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
        'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
        'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
        'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
        'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
        'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
        'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
        'mound_(baseball)', 'mouse_(animal_rodent)',
        'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
        'music_stool', 'musical_instrument', 'nailfile', 'nameplate', 'napkin',
        'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newsstand',
        'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)',
        'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
        'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
        'orange_(fruit)', 'orange_juice', 'oregano', 'ostrich', 'ottoman',
        'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
        'padlock', 'paintbox', 'paintbrush', 'painting', 'pajamas', 'palette',
        'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
        'papaya', 'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
        'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
        'parchment', 'parka', 'parking_meter', 'parrot',
        'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
        'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
        'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
        'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
        'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
        'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
        'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
        'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
        'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
        'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
        'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
        'plate', 'platter', 'playing_card', 'playpen', 'pliers',
        'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
        'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
        'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
        'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato',
        'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'printer',
        'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding',
        'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet',
        'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car',
        'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft',
        'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
        'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
        'recliner', 'record_player', 'red_cabbage', 'reflector',
        'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
        'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
        'Rollerblade', 'rolling_pin', 'root_beer',
        'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
        'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
        'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
        'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
        'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
        'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
        'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
        'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
        'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
        'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
        'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
        'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
        'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart',
        'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head',
        'shower_curtain', 'shredder_(for_paper)', 'sieve', 'signboard', 'silo',
        'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka',
        'ski_pole', 'skirt', 'sled', 'sleeping_bag', 'sling_(bandage)',
        'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
        'snowmobile', 'soap', 'soccer_ball', 'sock', 'soda_fountain',
        'carbonated_water', 'sofa', 'softball', 'solar_array', 'sombrero',
        'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk',
        'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear',
        'spectacles', 'spice_rack', 'spider', 'sponge', 'spoon', 'sportswear',
        'spotlight', 'squirrel', 'stapler_(stapling_machine)', 'starfish',
        'statue_(sculpture)', 'steak_(food)', 'steak_knife',
        'steamer_(kitchen_appliance)', 'steering_wheel', 'stencil',
        'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer',
        'stirrup', 'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light',
        'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
        'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
        'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
        'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
        'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
        'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
        'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
        'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
        'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
        'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
        'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
        'telephone_pole', 'telephoto_lens', 'television_camera',
        'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
        'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
        'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil',
        'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
        'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
        'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
        'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light',
        'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
        'tray', 'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
        'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
        'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
        'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
        'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
        'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
        'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
        'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
        'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
        'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
        'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski',
        'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam',
        'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair',
        'whipped_cream', 'whiskey', 'whistle', 'wick', 'wig', 'wind_chime',
        'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock',
        'wine_bottle', 'wine_bucket', 'wineglass', 'wing_chair',
        'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath',
        'wrench', 'wristband', 'wristlet', 'yacht', 'yak', 'yogurt',
        'yoke_(animal_equipment)', 'zebra', 'zucchini')

    def load_annotations(self, ann_file):
        """Load annotation from lvis style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from LVIS api.
        """

        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVIS
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        self.coco = LVIS(ann_file)
        assert not self.custom_classes, 'LVIS custom classes is not supported'
        self.cat_ids = self.coco.get_cat_ids()
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            if info['file_name'].startswith('COCO'):
                # Convert from the COCO 2014 file naming convention of
                # COCO_[train/val/test]2014_000000000000.jpg to the 2017
                # naming convention of 000000000000.jpg
                # (LVIS v1 will fix this naming issue)
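                # e.g. (hypothetical value) 'COCO_val2014_000000391895.jpg'[-16:]
                # gives '000000391895.jpg' (12-digit id + '.jpg' = 16 characters)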
                info['filename'] = info['file_name'][-16:]
            else:
                info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=np.arange(0.5, 0.96, 0.05)):
        """Evaluation in LVIS protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of output json files,
                including the file path and the filename prefix, e.g.
                'a/b/prefix'. If not specified, a temp file will be created.
                Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float]): IoU threshold used for evaluating
                recalls. If set to a list, the average recall of all IoUs will
                also be computed. Default: np.arange(0.5, 0.96, 0.05).

        Returns:
            dict[str, float]: LVIS style metrics.
        """

        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVISResults, LVISEval
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError('metric {} is not supported'.format(metric))

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)

        eval_results = {}
        # get original api
        lvis_gt = self.coco
        for metric in metrics:
            msg = 'Evaluating {}...'.format(metric)
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(results,
                                           proposal_nums,
                                           iou_thrs,
                                           logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results['AR@{}'.format(num)] = ar[i]
                    log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i]))
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            if metric not in result_files:
                raise KeyError('{} is not in results'.format(metric))
            try:
                lvis_dt = LVISResults(lvis_gt, result_files[metric])
            except IndexError:
                print_log('The testing results of the whole dataset is empty.',
                          logger=logger,
                          level=logging.ERROR)
                break

            iou_type = 'bbox' if metric == 'proposal' else metric
            lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type)
            lvis_eval.params.imgIds = self.img_ids
            if metric == 'proposal':
                lvis_eval.params.useCats = 0
                lvis_eval.params.maxDets = list(proposal_nums)
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                for k, v in lvis_eval.get_results().items():
                    if k.startswith('AR'):
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[k] = val
            else:
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                lvis_results = lvis_eval.get_results()
                if classwise:  # Compute per-category AP
                    # Compute per-category AP
                    # from https://github.com/facebookresearch/detectron2/
                    precisions = lvis_eval.eval['precision']
                    # precision: (iou, recall, cls, area range, max dets)
                    assert len(self.cat_ids) == precisions.shape[2]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        # max dets index -1: typically 100 per image
                        nm = self.coco.load_cats(catId)[0]
                        precision = precisions[:, :, idx, 0, -1]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                for k, v in lvis_results.items():
                    if k.startswith('AP'):
                        key = '{}_{}'.format(metric, k)
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[key] = val
                ap_summary = ' '.join([
                    '{}:{:.3f}'.format(k, float(v))
                    for k, v in lvis_results.items() if k.startswith('AP')
                ])
                eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary
            lvis_eval.print_results()
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results
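
A minimal sketch of how this evaluate() method is typically called; the dataset construction and the results list are placeholders that depend on the surrounding mmdetection config and test loop:

# Hypothetical: `dataset` is a built LVISV05Dataset instance and `results`
# holds one entry per image, as produced by the test loop.
metrics = dataset.evaluate(results, metric=['bbox', 'segm'], classwise=True)
print(metrics['bbox_AP'], metrics['bbox_mAP_copypaste'])
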
Example #4
    def __init__(self, annfile, dset_name, device='cuda', reduce='sum',
                 reduce_mini_batch=True):
        super(IDFTransformer, self).__init__()
        cwd = os.getenv('owd')
        self.reduce = reduce
        self.reduce_mini_batch = reduce_mini_batch
        self.loss = nn.BCELoss(reduction=reduce)
        annfile = os.path.join(cwd, annfile)
        idf_path = dset_name + "_files"
        idf_path = os.path.join(cwd, idf_path)
        self.device = device
        if not os.path.exists(idf_path):
            os.mkdir(idf_path)
        if not os.path.exists(os.path.join(idf_path,'idf.csv')):
            df = pd.DataFrame()
            if dset_name == 'coco':
                coco = COCO(annfile)
                ims = []
                last_cat = coco.getCatIds()[-1]
                num_classes = last_cat + 1  # +1 for the background class
                self.num_classes = num_classes
                for imgid in coco.getImgIds():
                    anids = coco.getAnnIds(imgIds=imgid)
                    categories=[]
                    for anid in anids:
                        categories.append(coco.loadAnns(int(anid))[0]['category_id'])
                    ims.append(categories)
            else:
                lvis = LVIS(annfile)
                ims = []
                last_cat = lvis.get_cat_ids()[-1]
                num_classes = last_cat + 1  # +1 for the background class
                self.num_classes = num_classes
                for imgid in lvis.get_img_ids():
                    anids = lvis.get_ann_ids(img_ids=[imgid])
                    categories=[]
                    for annot in lvis.load_anns(anids):
                        categories.append(annot['category_id'])
                    ims.append(categories)
            
            final=0
            k=0
            print('calculating idf ...')
            for im in tqdm(ims):
                cats = np.array(im, dtype=int)  # np.int is removed in NumPy >= 1.24
                cats = np.bincount(cats, minlength=self.num_classes)
                cats = np.array([cats])
                cats = coo_matrix(cats)
                if k == 0:
                    final = cats
                else:
                    final = vstack([cats, final])
                k = k + 1

            mask = final.sum(axis=0)>0
            mask=mask.tolist()[0]
            final = final.tocsr()[:,mask]
            self.num_classes = final.shape[1]
            doc_freq = (final>0).sum(axis=0)
            smooth = (np.log((final.shape[0]+1)/(doc_freq+1))+1).tolist()[0]
            raw = (np.log((final.shape[0])/(doc_freq))).tolist()[0]
            prob = (np.log((final.shape[0]-doc_freq)/(doc_freq))).tolist()[0]
            df['smooth'] = smooth
            df['raw'] = raw
            df['prob'] = prob
            self.idf_weights={} 
            self.idf_weights['smooth'] = torch.tensor(
                smooth, dtype=torch.float, device=self.device)
            self.idf_weights['raw'] = torch.tensor(
                raw, dtype=torch.float, device=self.device)
            self.idf_weights['prob'] = torch.tensor(
                prob, dtype=torch.float, device=self.device)
            df.to_csv(os.path.join(idf_path,'idf.csv'))

        else:
            df=pd.read_csv(os.path.join(idf_path,'idf.csv'))
            self.idf_weights = {}
            self.idf_weights['smooth'] = torch.tensor(
                df['smooth'], dtype=torch.float, device=self.device)
            self.idf_weights['raw'] = torch.tensor(
                df['raw'], dtype=torch.float, device=self.device)
            self.idf_weights['prob'] = torch.tensor(
                df['prob'], dtype=torch.float, device=self.device)

            self.num_classes = self.idf_weights['smooth'].shape[0]
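
As a self-contained sanity check of the three IDF variants computed above, a small sketch on a toy per-image category-count matrix (values are illustrative only; the formulas mirror the 'smooth', 'raw' and 'prob' columns written to idf.csv):

import numpy as np
from scipy.sparse import coo_matrix, vstack

# Toy counts: 3 images, 4 categories (the last one never occurs).
rows = [coo_matrix(np.array([r])) for r in ([1, 0, 2, 0], [0, 1, 1, 0], [3, 0, 0, 0])]
final = vstack(rows)

mask = (final.sum(axis=0) > 0).tolist()[0]   # drop categories that never occur
final = final.tocsr()[:, mask]
doc_freq = (final > 0).sum(axis=0)           # number of images containing each category
n_docs = final.shape[0]
smooth = np.log((n_docs + 1) / (doc_freq + 1)) + 1   # smoothed idf
raw = np.log(n_docs / doc_freq)                      # raw idf
prob = np.log((n_docs - doc_freq) / doc_freq)        # probabilistic idf
print(smooth, raw, prob)
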
Example #5
class MiniLvisDataSet(LvisDataSet):
    """Only contains the same 80 classes as COCO dataset"""
    # TODO: fix the lvis-coco classes
    LVIS_TO_COCO = {
        'computer_mouse': 'mouse',
        'cellphone': 'cell_phone',
        'hair_dryer': 'hair_drier',
        'sausage': 'hot_dog',  # not sure!!!
        'laptop_computer': 'laptop',
        'microwave_oven': 'microwave',
        'toaster_oven': 'oven',
        'flower_arrangement': 'potted_plant',  # not sure!!
        'remote_control': 'remote',
        'ski': 'skis',
        'baseball': 'sports_ball',  # too many types of balls
        'wineglass': 'wine_glass',
    }

    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        COCO_CLASSES = sorted(list(CocoDataset.CLASSES))
        self.synonyms_classes = [
            value['synonyms'] for value in self.lvis.cats.values()
        ]
        self.cat_ids = []
        self.CLASSES = []
        self.not_in_classes = COCO_CLASSES
        for id in self.lvis.get_cat_ids():
            for name in self.lvis.cats[id]['synonyms']:
                if name in self.LVIS_TO_COCO:
                    self.cat_ids.append(id)
                    self.CLASSES.append(name)
                    self.not_in_classes.remove(self.LVIS_TO_COCO[name])
                    break
                elif '(' not in name and name in COCO_CLASSES:
                    self.cat_ids.append(id)
                    self.CLASSES.append(name)
                    self.not_in_classes.remove(name)
                    break
                elif '_' in name:
                    new_name = name.split('_(')[0]
                    if new_name in COCO_CLASSES:
                        self.cat_ids.append(id)
                        self.CLASSES.append(name)
                        self.not_in_classes.remove(new_name)
                        break
        data_dir = osp.dirname(ann_file)
        with open(osp.join(data_dir, 'synonyms_classes.json'), 'w') as f:
            f.write(json.dumps(self.synonyms_classes, indent=2))
        with open(osp.join(data_dir, 'not_in_classes.json'), 'w') as f:
            f.write(json.dumps(self.not_in_classes, indent=2))
        with open(osp.join(data_dir, 'coco_classes.json'), 'w') as f:
            f.write(json.dumps(COCO_CLASSES, indent=2))
        with open(osp.join(data_dir, 'lvis_coco_classes.json'), 'w') as f:
            f.write(json.dumps(self.CLASSES, indent=2))
        self.CLASSES = tuple(self.CLASSES)
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.CLASSES = CocoDataset.CLASSES  # note: overrides the LVIS-derived names built above
        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx):
        img_id = self.img_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(ann_info, self.with_mask)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.lvis.anns.values())
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds
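
# Illustrative check of the synonym-matching rules in load_annotations above.
# The names below are assumptions chosen to exercise each branch; they may not
# all appear in a given LVIS annotation file.
LVIS_TO_COCO = {'cellphone': 'cell_phone', 'wineglass': 'wine_glass'}
COCO_CLASSES = ['cell_phone', 'wine_glass', 'bus', 'bench']

def coco_match(name):
    if name in LVIS_TO_COCO:                      # explicit rename table
        return LVIS_TO_COCO[name]
    if '(' not in name and name in COCO_CLASSES:  # exact name match
        return name
    if '_' in name:                               # strip a parenthetical qualifier
        stripped = name.split('_(')[0]
        if stripped in COCO_CLASSES:
            return stripped
    return None

print(coco_match('cellphone'))      # -> 'cell_phone'
print(coco_match('bench'))          # -> 'bench'
print(coco_match('bus_(vehicle)'))  # -> 'bus'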
Exemplo n.º 6
0
class LvisDataset1(CustomDataset):
    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        self.full_cat_ids = self.lvis.get_cat_ids()
        self.full_cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.full_cat_ids)
        }

        self.CLASSES = tuple(
            [item['name'] for item in self.lvis.dataset['categories']])
        self.cat_ids = self.lvis.get_cat_ids()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }

        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            # info['filename'] = info['file_name'].split('_')[-1]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx):
        img_id = self.data_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(self.data_infos[idx], ann_info)

    def get_ann_info_withoutparse(self, idx):
        img_id = self.data_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return ann_info

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.lvis.anns.values())
        for i, img_info in enumerate(self.data_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        # Two formats are provided.
        # 1. mask: a binary map of the same size of the image.
        # 2. polys: each mask consists of one or several polys, each poly is a
        # list of float.
        """gt_masks = []

        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            if 'iscrowd' in ann.keys():
                if ann['iscrowd']:
                    gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])

            gt_masks.append(self.lvis.ann_to_mask(ann))"""

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        # seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(bboxes=gt_bboxes,
                   labels=gt_labels,
                   bboxes_ignore=gt_bboxes_ignore)
        # masks=gt_masks,
        # seg_map=seg_map)

        return ann
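
# Hypothetical inspection of one sample (index 0 is a placeholder and `dataset`
# is assumed to be a LvisDataset1 built with a valid LVIS annotation file).
ann = dataset.get_ann_info(0)
print(ann['bboxes'].shape)   # (num_gt, 4), boxes as [x1, y1, x2, y2]
print(ann['labels'].dtype)   # int64, values taken from the 1-based cat2label
raw = dataset.get_ann_info_withoutparse(0)
print(raw[0].keys())         # untouched LVIS annotation fields for the first instance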
Exemplo n.º 7
0
class OpenimageDataset(CocoDataset):

    CLASSES = (
        'Infant bed', 'Rose', 'Flag', 'Flashlight', 'Sea turtle', 'Camera',
        'Animal', 'Glove', 'Crocodile', 'Cattle', 'House', 'Guacamole',
        'Penguin', 'Vehicle registration plate', 'Bench', 'Ladybug',
        'Human nose', 'Watermelon', 'Flute', 'Butterfly', 'Washing machine',
        'Raccoon', 'Segway', 'Taco', 'Jellyfish', 'Cake', 'Pen', 'Cannon',
        'Bread', 'Tree', 'Shellfish', 'Bed', 'Hamster', 'Hat', 'Toaster',
        'Sombrero', 'Tiara', 'Bowl', 'Dragonfly', 'Moths and butterflies',
        'Antelope', 'Vegetable', 'Torch', 'Building',
        'Power plugs and sockets', 'Blender', 'Billiard table',
        'Cutting board', 'Bronze sculpture', 'Turtle', 'Broccoli', 'Tiger',
        'Mirror', 'Bear', 'Zucchini', 'Dress', 'Volleyball', 'Guitar',
        'Reptile', 'Golf cart', 'Tart', 'Fedora', 'Carnivore', 'Car',
        'Lighthouse', 'Coffeemaker', 'Food processor', 'Truck', 'Bookcase',
        'Surfboard', 'Footwear', 'Bench', 'Necklace', 'Flower', 'Radish',
        'Marine mammal', 'Frying pan', 'Tap', 'Peach', 'Knife', 'Handbag',
        'Laptop', 'Tent', 'Ambulance', 'Christmas tree', 'Eagle', 'Limousine',
        'Kitchen & dining room table', 'Polar bear', 'Tower', 'Football',
        'Willow', 'Human head', 'Stop sign', 'Banana', 'Mixer', 'Binoculars',
        'Dessert', 'Bee', 'Chair', 'Wood-burning stove', 'Flowerpot', 'Beaker',
        'Oyster', 'Woodpecker', 'Harp', 'Bathtub', 'Wall clock',
        'Sports uniform', 'Rhinoceros', 'Beehive', 'Cupboard',
        'Chicken', 'Man', 'Blue jay', 'Cucumber', 'Balloon', 'Kite',
        'Fireplace', 'Lantern', 'Missile', 'Book', 'Spoon', 'Grapefruit',
        'Squirrel', 'Orange', 'Coat', 'Punching bag', 'Zebra', 'Billboard',
        'Bicycle', 'Door handle', 'Mechanical fan', 'Ring binder', 'Table',
        'Parrot', 'Sock', 'Vase', 'Weapon', 'Shotgun', 'Glasses', 'Seahorse',
        'Belt', 'Watercraft', 'Window', 'Giraffe', 'Lion', 'Tire', 'Vehicle',
        'Canoe', 'Tie', 'Shelf', 'Picture frame', 'Printer', 'Human leg',
        'Boat', 'Slow cooker', 'Croissant', 'Candle', 'Pancake', 'Pillow',
        'Coin', 'Stretcher', 'Sandal', 'Woman', 'Stairs', 'Harpsichord',
        'Stool', 'Bus', 'Suitcase', 'Human mouth', 'Juice', 'Skull', 'Door',
        'Violin', 'Chopsticks', 'Digital clock', 'Sunflower', 'Leopard',
        'Bell pepper', 'Harbor seal', 'Snake', 'Sewing machine', 'Goose',
        'Helicopter', 'Seat belt', 'Coffee cup', 'Microwave oven', 'Hot dog',
        'Countertop', 'Serving tray', 'Dog bed', 'Beer', 'Sunglasses',
        'Golf ball', 'Waffle', 'Palm tree', 'Trumpet', 'Ruler', 'Helmet',
        'Ladder', 'Office building', 'Tablet computer', 'Toilet paper',
        'Pomegranate', 'Skirt', 'Gas stove', 'Cookie', 'Cart', 'Raven', 'Egg',
        'Burrito', 'Goat', 'Kitchen knife', 'Skateboard',
        'Salt and pepper shakers', 'Lynx', 'Boot', 'Platter', 'Ski',
        'Swimwear', 'Swimming pool', 'Drinking straw', 'Wrench', 'Drum', 'Ant',
        'Human ear', 'Headphones', 'Fountain', 'Bird', 'Jeans', 'Television',
        'Crab', 'Microphone', 'Home appliance', 'Snowplow', 'Beetle',
        'Artichoke', 'Jet ski', 'Stationary bicycle', 'Human hair',
        'Brown bear', 'Starfish', 'Fork', 'Lobster', 'Corded phone', 'Drink',
        'Saucer', 'Carrot', 'Insect', 'Clock', 'Castle', 'Tennis racket',
        'Ceiling fan', 'Asparagus', 'Jaguar', 'Musical instrument', 'Train',
        'Cat', 'Rifle', 'Dumbbell', 'Mobile phone', 'Taxi', 'Shower',
        'Pitcher', 'Lemon', 'Invertebrate', 'Turkey', 'High heels', 'Bust',
        'Elephant', 'Scarf', 'Barrel', 'Trombone', 'Pumpkin', 'Box', 'Tomato',
        'Frog', 'Bidet', 'Human face', 'Houseplant', 'Van', 'Shark',
        'Ice cream', 'Swim cap', 'Falcon', 'Ostrich', 'Handgun', 'Whiteboard',
        'Lizard', 'Pasta', 'Snowmobile', 'Light bulb', 'Window blind',
        'Muffin', 'Pretzel', 'Computer monitor', 'Horn', 'Furniture',
        'Sandwich', 'Fox', 'Convenience store', 'Fish', 'Fruit', 'Earrings',
        'Curtain', 'Grape', 'Sofa bed', 'Horse', 'Luggage and bags', 'Desk',
        'Crutch', 'Bicycle helmet', 'Tick', 'Airplane', 'Canary', 'Spatula',
        'Watch', 'Lily', 'Kitchen appliance', 'Filing cabinet', 'Aircraft',
        'Cake stand', 'Candy', 'Sink', 'Mouse', 'Wine', 'Wheelchair',
        'Goldfish', 'Refrigerator', 'French fries', 'Drawer', 'Treadmill',
        'Picnic basket', 'Dice', 'Cabbage', 'Football helmet', 'Pig', 'Person',
        'Shorts', 'Gondola', 'Honeycomb', 'Doughnut', 'Chest of drawers',
        'Land vehicle', 'Bat', 'Monkey', 'Dagger', 'Tableware', 'Human foot',
        'Mug', 'Alarm clock', 'Pressure cooker', 'Human hand', 'Tortoise',
        'Baseball glove', 'Sword', 'Pear', 'Miniskirt', 'Traffic sign', 'Girl',
        'Roller skates', 'Dinosaur', 'Porch', 'Human beard',
        'Submarine sandwich', 'Screwdriver', 'Strawberry', 'Wine glass',
        'Seafood', 'Racket', 'Wheel', 'Sea lion', 'Toy', 'Tea', 'Tennis ball',
        'Waste container', 'Mule', 'Cricket ball', 'Pineapple', 'Coconut',
        'Doll', 'Coffee table', 'Snowman', 'Lavender', 'Shrimp', 'Maple',
        'Cowboy hat', 'Goggles', 'Rugby ball', 'Caterpillar', 'Poster',
        'Rocket', 'Organ', 'Saxophone', 'Traffic light', 'Cocktail',
        'Plastic bag', 'Squash', 'Mushroom', 'Hamburger', 'Light switch',
        'Parachute', 'Teddy bear', 'Winter melon', 'Deer', 'Musical keyboard',
        'Plumbing fixture', 'Scoreboard', 'Baseball bat', 'Envelope',
        'Adhesive tape', 'Briefcase', 'Paddle', 'Bow and arrow', 'Telephone',
        'Sheep', 'Jacket', 'Boy', 'Pizza', 'Otter', 'Office supplies', 'Couch',
        'Cello', 'Bull', 'Camel', 'Ball', 'Duck', 'Whale', 'Shirt', 'Tank',
        'Motorcycle', 'Accordion', 'Owl', 'Porcupine', 'Sun hat', 'Nail',
        'Scissors', 'Swan', 'Lamp', 'Crown', 'Piano', 'Sculpture', 'Cheetah',
        'Oboe', 'Tin can', 'Mango', 'Tripod', 'Oven', 'Mouse', 'Barge',
        'Coffee', 'Snowboard', 'Common fig', 'Salad', 'Marine invertebrates',
        'Umbrella', 'Kangaroo', 'Human arm', 'Measuring cup', 'Snail',
        'Loveseat', 'Suit', 'Teapot', 'Bottle', 'Alpaca', 'Kettle', 'Trousers',
        'Popcorn', 'Centipede', 'Spider', 'Sparrow', 'Plate', 'Bagel',
        'Personal care', 'Apple', 'Brassiere', 'Bathroom cabinet',
        'studio couch', 'Computer keyboard', 'Table tennis racket', 'Sushi',
        'Cabinetry', 'Street light', 'Towel', 'Nightstand', 'Rabbit',
        'Dolphin', 'Dog', 'Jug', 'Wok', 'Fire hydrant', 'Human eye',
        'Skyscraper', 'Backpack', 'Potato', 'Paper towel', 'Lifejacket',
        'Bicycle wheel', 'Toilet')

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api.
        """
        try:
            from lvis import LVIS
        except ImportError:
            raise ImportError('Please follow config/lvis/README.md to '
                              'install open-mmlab forked lvis first.')
        self.coco = LVIS(ann_file)
        self.cat_ids = self.coco.get_cat_ids()
        # labels are 0-based here, unlike the 1-based cat2label used in the
        # other examples in this file
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 groupwise=True,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=np.array([0.5])):
        """Evaluation in LVIS protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): Prefix of the output json files. If
                not specified, a temporary directory is used. Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float]): IoU thresholds used for evaluating
                recalls. If set to a list, the average recall over all IoUs
                will also be computed. Default: np.array([0.5]).

        Returns:
            dict[str, float]: LVIS style metrics.
        """

        try:
            from lvis import LVISResults
            from .openimage_eval import OpenimageEval
        except ImportError:
            raise ImportError('Please follow config/lvis/README.md to '
                              'install open-mmlab forked lvis first.')
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError('metric {} is not supported'.format(metric))
        if iou_thrs is None:
            iou_thrs = np.linspace(.5,
                                   0.95,
                                   int(np.round((0.95 - .5) / .05)) + 1,
                                   endpoint=True)

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)

        eval_results = {}
        # get original api
        lvis_gt = self.coco
        for metric in metrics:
            msg = 'Evaluating {}...'.format(metric)
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(results,
                                           proposal_nums,
                                           iou_thrs,
                                           logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results['AR@{}'.format(num)] = ar[i]
                    log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i]))
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            if metric not in result_files:
                raise KeyError('{} is not in results'.format(metric))
            try:
                lvis_dt = LVISResults(lvis_gt, result_files[metric])
            except IndexError:
                print_log('The testing results of the whole dataset are empty.',
                          logger=logger,
                          level=logging.ERROR)
                break

            iou_type = 'bbox' if metric == 'proposal' else metric
            lvis_eval = OpenimageEval(lvis_gt, lvis_dt, iou_type)
            lvis_eval.params.imgIds = self.img_ids
            lvis_eval.params.iouThrs = iou_thrs
            if metric == 'proposal':
                lvis_eval.params.useCats = 0
                lvis_eval.params.maxDets = list(proposal_nums)
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                for k, v in lvis_eval.get_results().items():
                    if k.startswith('AR'):
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[k] = val
            else:
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                lvis_results = lvis_eval.get_results()
                classwise = True  # per-category AP is always computed here
                if classwise:
                    # Compute per-category AP
                    # from https://github.com/facebookresearch/detectron2/
                    precisions = lvis_eval.eval['precision']
                    # precision: (iou, recall, cls, area range)
                    assert len(self.cat_ids) == precisions.shape[2]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        # max dets index -1: typically 100 per image
                        nm = self.coco.load_cats([catId])[0]
                        precision = precisions[:, :, idx, 0]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                    with open(f"per-category-ap-{metric}.txt", 'w') as f:
                        f.write(table.table)

                for k, v in lvis_results.items():
                    if k.startswith('AP'):
                        key = '{}_{}'.format(metric, k)
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[key] = val
                ap_summary = ' '.join([
                    '{}:{:.3f}'.format(k, float(v))
                    for k, v in lvis_results.items() if k.startswith('AP')
                ])
                eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary
            lvis_eval.print_results()
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,\
                labels, masks, seg_map. "masks" are raw annotations and not \
                decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
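            # inter_w/inter_h are the box's overlap with the image; e.g. a box at
            # x1=700 with w=50 in a 640-px-wide image gives inter_w = 0, so it is
            # skipped by the check below (values here are made-up).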
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(bboxes=gt_bboxes,
                   labels=gt_labels,
                   bboxes_ignore=gt_bboxes_ignore,
                   masks=gt_masks_ann,
                   seg_map=seg_map)

        return ann
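
# Hypothetical call of OpenimageDataset.evaluate (the variable names are
# placeholders): `dataset` is a built OpenimageDataset and `results` holds
# per-image detections in the list format expected by results2json.
metrics = dataset.evaluate(results, metric='bbox', classwise=True)
print(metrics['bbox_AP'])             # keys follow the '{metric}_{AP...}' pattern built above
print(metrics['bbox_mAP_copypaste'])  # one-line summary of all AP entries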
Exemplo n.º 8
0
class LvisDataSet(CustomDataset):
    def __init__(self, samples_per_cls_file=None, **kwargs):
        self.samples_per_cls_file = samples_per_cls_file
        super(LvisDataSet, self).__init__(**kwargs)

    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        self.cat_ids = self.lvis.get_cat_ids()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        # per-category bookkeeping, indexed by (category id - 1); filled in the
        # loop over self.lvis.cats below
        self.CLASSES = [_ for _ in self.cat_ids]
        self.cat_instance_count = [_ for _ in self.cat_ids]
        self.cat_image_count = [_ for _ in self.cat_ids]
        img_count_lbl = ["r", "c", "f"]
        self.freq_groups = [[] for _ in img_count_lbl]
        self.cat_group_idxs = [_ for _ in self.cat_ids]
        freq_group_count = {'f': 0, 'cf': 0, 'rcf': 0}
        self.cat_fake_idxs = {
            'f': [-1 for _ in self.cat_ids],
            'cf': [-1 for _ in self.cat_ids],
            'rcf': [-1 for _ in self.cat_ids]
        }
        self.freq_group_dict = {'rcf': (0, 1, 2), 'cf': (1, 2), 'f': (2, )}
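        # cat_fake_idxs re-indexes categories contiguously (1-based) within each
        # group union: 'rcf' covers all categories, 'cf' skips rare ones and 'f'
        # keeps only frequent ones; -1 means the category is not in that union.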
        for value in self.lvis.cats.values():
            idx = value['id'] - 1
            self.CLASSES[idx] = value['name']
            self.cat_instance_count[idx] = value['instance_count']
            self.cat_image_count[idx] = value['image_count']
            group_idx = img_count_lbl.index(value["frequency"])
            self.freq_groups[group_idx].append(idx + 1)
            self.cat_group_idxs[idx] = group_idx
            if group_idx == 0:  # rare
                freq_group_count['rcf'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
            elif group_idx == 1:  # common
                freq_group_count['rcf'] += 1
                freq_group_count['cf'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
                self.cat_fake_idxs['cf'][idx] = freq_group_count['cf']
            elif group_idx == 2:  # freq
                freq_group_count['rcf'] += 1
                freq_group_count['cf'] += 1
                freq_group_count['f'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
                self.cat_fake_idxs['cf'][idx] = freq_group_count['cf']
                self.cat_fake_idxs['f'][idx] = freq_group_count['f']

        if self.samples_per_cls_file is not None:
            with open(self.samples_per_cls_file, 'w') as file:
                file.writelines(str(x) + '\n' for x in self.cat_instance_count)

        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx, freq_groups=('rcf', )):
        img_id = self.img_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(ann_info,
                                    self.with_mask,
                                    freq_groups=freq_groups)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.lvis.anns.values())
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, ann_info, with_mask=True, freq_groups=('rcf', )):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        assert isinstance(freq_groups, tuple)
        gt_valid_idxs = {name: [] for name in freq_groups}
        gt_count = 0
        # Two formats are provided.
        # 1. mask: a binary map of the same size of the image.
        # 2. polys: each mask consists of one or several polys, each poly is a
        # list of float.
        if with_mask:
            gt_masks = []
            gt_mask_polys = []
            gt_poly_lens = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue

            for name in freq_groups:
                if self.cat_group_idxs[ann['category_id'] -
                                       1] in self.freq_group_dict[name]:
                    gt_valid_idxs[name].append(gt_count)
            gt_count += 1

            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            gt_bboxes.append(bbox)
            gt_labels.append(self.cat2label[ann['category_id']])
            if with_mask:
                gt_masks.append(self.lvis.ann_to_mask(ann))
                mask_polys = [
                    p for p in ann['segmentation'] if len(p) >= 6
                ]  # valid polygons have >= 3 points (6 coordinates)
                poly_lens = [len(p) for p in mask_polys]
                gt_mask_polys.append(mask_polys)
                gt_poly_lens.extend(poly_lens)
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            gt_valid_idxs=gt_valid_idxs  # add gt_valid_idxs
        )

        if with_mask:
            ann['masks'] = gt_masks
            # poly format is not used in the current implementation
            ann['mask_polys'] = gt_mask_polys
            ann['poly_lens'] = gt_poly_lens
        return ann

    def prepare_train_img(self, idx):
        img_info = self.img_infos[idx]
        # load image
        if 'COCO' in img_info['filename']:
            img = mmcv_custom.imread(
                osp.join(
                    self.img_prefix, img_info['filename']
                    [img_info['filename'].find('COCO_val2014_') +
                     len('COCO_val2014_'):]))
        else:
            img = mmcv_custom.imread(
                osp.join(self.img_prefix, img_info['filename']))
        # corruption
        if self.corruption is not None:
            img = corrupt(img,
                          severity=self.corruption_severity,
                          corruption_name=self.corruption)
        # load proposals if necessary
        if self.proposals is not None:
            proposals = self.proposals[idx][:self.num_max_proposals]
            # TODO: Handle empty proposals properly. Currently images with
            # no proposals are just ignored, but in principle they could
            # still be used for training.
            if len(proposals) == 0:
                return None
            if not (proposals.shape[1] == 4 or proposals.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposals.shape))
            if proposals.shape[1] == 5:
                scores = proposals[:, 4, None]
                proposals = proposals[:, :4]
            else:
                scores = None

        ann = self.get_ann_info(idx, freq_groups=('rcf', 'cf', 'f'))
        gt_bboxes = ann['bboxes']
        gt_labels = ann['labels']
        gt_valid_idxs = ann['gt_valid_idxs']
        if self.with_crowd:
            gt_bboxes_ignore = ann['bboxes_ignore']

        # skip the image if there is no valid gt bbox
        if len(gt_bboxes) == 0 and self.skip_img_without_anno:
            warnings.warn('Skip the image "%s" that has no valid gt bbox' %
                          osp.join(self.img_prefix, img_info['filename']))
            return None

        # extra augmentation
        if self.extra_aug is not None:
            img, gt_bboxes, gt_labels = self.extra_aug(img, gt_bboxes,
                                                       gt_labels)

        # apply transforms
        flip = True if np.random.rand() < self.flip_ratio else False
        # randomly sample a scale
        img_scale = random_scale(self.img_scales, self.multiscale_mode)
        img, img_shape, pad_shape, scale_factor = self.img_transform(
            img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        img = img.copy()
        if self.with_seg:
            gt_seg = mmcv_custom.imread(osp.join(
                self.seg_prefix, img_info['filename'].replace('jpg', 'png')),
                                        flag='unchanged')
            gt_seg = self.seg_transform(gt_seg.squeeze(), img_scale, flip)
            gt_seg = mmcv.imrescale(gt_seg,
                                    self.seg_scale_factor,
                                    interpolation='nearest')
            gt_seg = gt_seg[None, ...]
        if self.proposals is not None:
            proposals = self.bbox_transform(proposals, img_shape, scale_factor,
                                            flip)
            proposals = np.hstack([proposals, scores
                                   ]) if scores is not None else proposals

        gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
                                        flip)

        if self.with_crowd:
            gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
                                                   scale_factor, flip)
        if self.with_mask:
            gt_masks = self.mask_transform(ann['masks'], pad_shape,
                                           scale_factor, flip)

        ori_shape = (img_info['height'], img_info['width'], 3)
        not_exhaustive_category_ids = img_info['not_exhaustive_category_ids']
        neg_category_ids = img_info['neg_category_ids']
        img_meta = dict(
            ori_shape=ori_shape,
            img_shape=img_shape,
            pad_shape=pad_shape,
            scale_factor=scale_factor,
            flip=flip,
            not_exhaustive_category_ids=not_exhaustive_category_ids,
            neg_category_ids=neg_category_ids,
            # cat_group_idxs=self.cat_group_idxs,
            cat_instance_count=self.cat_instance_count,
            freq_groups=self.freq_groups,
            cat_fake_idxs=self.cat_fake_idxs,
            freq_group_dict=self.freq_group_dict,
        )

        data = dict(
            img=DC(to_tensor(img), stack=True),
            img_meta=DC(img_meta, cpu_only=True),
            gt_bboxes=DC(to_tensor(gt_bboxes)),
            gt_valid_idxs=DC(gt_valid_idxs, cpu_only=True),
        )
        if self.proposals is not None:
            data['proposals'] = DC(to_tensor(proposals))
        if self.with_label:
            data['gt_labels'] = DC(to_tensor(gt_labels))
        if self.with_crowd:
            data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore))
        if self.with_mask:
            data['gt_masks'] = DC(gt_masks, cpu_only=True)
        if self.with_seg:
            data['gt_semantic_seg'] = DC(to_tensor(gt_seg), stack=True)
        return data

    def prepare_test_img(self, idx):
        """Prepare an image for testing (multi-scale and flipping)"""
        img_info = self.img_infos[idx]
        # load image
        if 'COCO' in img_info['filename']:
            img = mmcv_custom.imread(
                osp.join(
                    self.img_prefix, img_info['filename']
                    [img_info['filename'].find('COCO_val2014_') +
                     len('COCO_val2014_'):]))
        else:
            img = mmcv_custom.imread(
                osp.join(self.img_prefix, img_info['filename']))
        # corruption
        if self.corruption is not None:
            img = corrupt(img,
                          severity=self.corruption_severity,
                          corruption_name=self.corruption)
        # load proposals if necessary
        if self.proposals is not None:
            proposal = self.proposals[idx][:self.num_max_proposals]
            if not (proposal.shape[1] == 4 or proposal.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposal.shape))
        else:
            proposal = None

        def prepare_single(img, scale, flip, proposal=None):
            _img, img_shape, pad_shape, scale_factor = self.img_transform(
                img, scale, flip, keep_ratio=self.resize_keep_ratio)
            _img = to_tensor(_img)
            _img_meta = dict(ori_shape=(img_info['height'], img_info['width'],
                                        3),
                             img_shape=img_shape,
                             pad_shape=pad_shape,
                             scale_factor=scale_factor,
                             flip=flip)
            if proposal is not None:
                if proposal.shape[1] == 5:
                    score = proposal[:, 4, None]
                    proposal = proposal[:, :4]
                else:
                    score = None
                _proposal = self.bbox_transform(proposal, img_shape,
                                                scale_factor, flip)
                _proposal = np.hstack([_proposal, score
                                       ]) if score is not None else _proposal
                _proposal = to_tensor(_proposal)
            else:
                _proposal = None
            return _img, _img_meta, _proposal

        imgs = []
        img_metas = []
        proposals = []
        for scale in self.img_scales:
            _img, _img_meta, _proposal = prepare_single(
                img, scale, False, proposal)
            imgs.append(_img)
            img_metas.append(DC(_img_meta, cpu_only=True))
            proposals.append(_proposal)
            if self.flip_ratio > 0:
                _img, _img_meta, _proposal = prepare_single(
                    img, scale, True, proposal)
                imgs.append(_img)
                img_metas.append(DC(_img_meta, cpu_only=True))
                proposals.append(_proposal)
        data = dict(img=imgs, img_meta=img_metas)
        if self.proposals is not None:
            data['proposals'] = proposals
        return data
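
# Illustrative only: how the per-group index lists returned by get_ann_info look.
# `dataset` is assumed to be a built LvisDataSet; suppose image 0 has three gt
# instances whose categories are frequent, rare and common, in that order.
ann = dataset.get_ann_info(0, freq_groups=('rcf', 'cf', 'f'))
print(ann['gt_valid_idxs']['rcf'])  # e.g. [0, 1, 2] -> every gt kept
print(ann['gt_valid_idxs']['cf'])   # e.g. [0, 2]    -> the rare gt at index 1 dropped
print(ann['gt_valid_idxs']['f'])    # e.g. [0]       -> only the frequent gt remains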
Exemplo n.º 9
0
class LvisDataset(CustomDataset):

    CLASSES = []
    for i in range(1, 1231):
        CLASSES.append(str(i))
    CLASSES = tuple(CLASSES)

    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        self.cat_ids = self.lvis.get_cat_ids()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx):
        img_id = self.img_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(self.img_infos[idx], ann_info)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.lvis.anns.values())
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, seg_map. "masks" are raw annotations and not
                decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []

        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            # inclusive pixel-index convention (cf. the x1 + w form in OpenimageDataset above)
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann['segmentation'])

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(bboxes=gt_bboxes,
                   labels=gt_labels,
                   bboxes_ignore=gt_bboxes_ignore,
                   masks=gt_masks_ann,
                   seg_map=seg_map)

        return ann
Exemplo n.º 10
0
    def eval_cocofied_lvis_result(self, gt_file, result_file, metric='segm'):

        def get_lvis_format_result(lvis_params, lvis_results):
            template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}"

            result_list = []
            for key, value in lvis_results.items():
                max_dets = lvis_params.max_dets
                if "AP" in key:
                    title = "Average Precision"
                    _type = "(AP)"
                else:
                    title = "Average Recall"
                    _type = "(AR)"

                if len(key) > 2 and key[2].isdigit():
                    iou_thr = (float(key[2:]) / 100)
                    iou = "{:0.2f}".format(iou_thr)
                else:
                    iou = "{:0.2f}:{:0.2f}".format(
                        lvis_params.iou_thrs[0], lvis_params.iou_thrs[-1]
                    )

                if len(key) > 2 and key[2] in ["r", "c", "f"]:
                    cat_group_name = key[2]
                else:
                    cat_group_name = "all"

                if len(key) > 2 and key[2] in ["s", "m", "l"]:
                    area_rng = key[2]
                else:
                    area_rng = "all"

                result_list.append(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value))
            return result_list

        print('load gt json')
        lvis_gt = LVIS(gt_file)
        cat_ids = lvis_gt.get_cat_ids()

        print('load pred json')
        lvis_dt = LVISResults(lvis_gt, result_file)

        print('evaluating')
        lvis_eval = LVISEval(lvis_gt, lvis_dt, metric)
        lvis_eval.params.imgIds = lvis_gt.get_img_ids()

        lvis_eval.evaluate()
        lvis_eval.accumulate()
        lvis_eval.summarize()

        # Compute per-category AP
        precisions = lvis_eval.eval['precision']
        assert len(cat_ids) == precisions.shape[2]

        results_per_category = []
        for idx, catId in enumerate(cat_ids):
            nm = lvis_gt.load_cats([catId])[0]
            precision = precisions[:, :, idx, 0]
            precision = precision[precision > -1]
            if precision.size:
                ap = np.mean(precision)
            else:
                ap = float('nan')
            results_per_category.append(
                (f'{nm["name"]}', f'{float(ap):0.3f}'))

        num_columns = min(6, len(results_per_category) * 2)
        results_flatten = list(
            itertools.chain(*results_per_category))
        headers = ['category', 'AP'] * (num_columns // 2)
        results_2d = itertools.zip_longest(*[
            results_flatten[i::num_columns]
            for i in range(num_columns)
        ])
        table_data = [headers]
        table_data += [result for result in results_2d]
        table = AsciiTable(table_data)
        print_log('\n' + table.table)

        format_summary_result_list = get_lvis_format_result(lvis_eval.params, lvis_eval.results)
        format_summary_result = "\n".join(format_summary_result_list)

        with open(f"cocofied_per-category-ap-{metric}.txt", 'w') as f:
            f.write(table.table + "\n" + format_summary_result)

        lvis_eval.print_results()
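
# Hypothetical invocation (paths and the `evaluator` name are placeholders for
# whatever class this method belongs to):
evaluator.eval_cocofied_lvis_result(
    gt_file='data/lvis/lvis_v0.5_val_cocofied.json',
    result_file='work_dirs/mask_rcnn/results.segm.json',
    metric='segm')
# This prints the per-category AP table, then writes it together with the
# formatted summary lines to "cocofied_per-category-ap-segm.txt".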