Exemplo n.º 1
0
def build(image_set, args):
    """Build the COCO-format detection dataset for one split.

    Args:
        image_set: split name, ``"train"`` or ``"val"``.
        args: parsed options; ``coco_path``, ``masks`` and ``cache_mode`` are read.

    Returns:
        A ``CocoDetection`` over ``<coco_path>/<image_set>`` with annotations read
        from ``<coco_path>/annotations/instances_<image_set>.json``.

    Raises:
        FileNotFoundError: if ``args.coco_path`` does not exist.
        ValueError: if ``image_set`` is not one of the known splits.
    """
    root = Path(args.coco_path)
    # Explicit check rather than ``assert`` so validation still runs under ``python -O``.
    if not root.exists():
        raise FileNotFoundError(f'provided COCO path {root} does not exist')
    mode = 'instances'
    paths = {
        "train": (root / "train", root / "annotations" / f'{mode}_train.json'),
        "val": (root / "val", root / "annotations" / f'{mode}_val.json'),
    }
    if image_set not in paths:
        raise ValueError(f'unknown image_set {image_set!r}; expected one of {sorted(paths)}')

    img_folder, ann_file = paths[image_set]
    return CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set),
                         return_masks=args.masks, cache_mode=args.cache_mode,
                         local_rank=get_local_rank(), local_size=get_local_size())
Exemplo n.º 2
0
def build(args, img_folder, ann_file, image_set, activated_class_ids,
          with_support):
    """Wrap an image folder and annotation file in a ``DetectionDataset``.

    Query transforms are selected by ``image_set`` via ``make_transforms``;
    support transforms come from ``make_support_transforms``. Masks are never
    requested. ``cache_mode`` is read from ``args``, and the local rank/size
    come from ``get_local_rank()`` / ``get_local_size()``.
    """
    dataset_kwargs = dict(
        transforms=make_transforms(image_set),
        support_transforms=make_support_transforms(),
        return_masks=False,
        activated_class_ids=activated_class_ids,
        with_support=with_support,
        cache_mode=args.cache_mode,
        local_rank=get_local_rank(),
        local_size=get_local_size(),
    )
    return DetectionDataset(args, img_folder, ann_file, **dataset_kwargs)
Exemplo n.º 3
0
def build(image_set, args):
    """Build the MOT15 detection dataset for one split.

    The dataset root defaults to the original hard-coded location and can be
    overridden by setting ``args.mot15_path``.

    Args:
        image_set: split name, ``"train"`` or ``"val"``.
        args: parsed options; ``masks`` and ``cache_mode`` are read, and the
            optional ``mot15_path`` if present.

    Returns:
        A ``MOT15Detection`` over ``<root>/images/<image_set>`` with annotations
        read from ``<root>/labels_with_ids/<image_set>.json``.

    Raises:
        FileNotFoundError: if the dataset root does not exist.
        ValueError: if ``image_set`` is not one of the known splits.
    """
    # The machine-specific default is kept for backward compatibility; an override
    # on ``args`` makes the builder usable on other machines.
    root = Path(getattr(args, 'mot15_path', None)
                or "/home/eini/WY/project/Deformable-DETR/data/MOT15")
    if not root.exists():
        raise FileNotFoundError(f'provided MOT15 path {root} does not exist')
    paths = {
        "train": (root / "images/train", root / "labels_with_ids" / 'train.json'),
        "val": (root / "images/val", root / "labels_with_ids" / 'val.json'),
    }
    if image_set not in paths:
        raise ValueError(f'unknown image_set {image_set!r}; expected one of {sorted(paths)}')

    img_folder, ann_file = paths[image_set]
    return MOT15Detection(img_folder, ann_file, transforms=make_mot15_transforms(image_set),
                          return_masks=args.masks, cache_mode=args.cache_mode,
                          local_rank=get_local_rank(), local_size=get_local_size())
Exemplo n.º 4
0
def _make_support_dataset(img_folder, ann_file, class_ids, args):
    """Create a SupportDataset with the shared support transforms and caching options."""
    return SupportDataset(img_folder, ann_file,
                          activatedClassIds=class_ids,
                          transforms=make_support_transforms(),
                          cache_mode=args.cache_mode,
                          local_rank=get_local_rank(),
                          local_size=get_local_size())


def _build_train_support(args):
    """Support dataset built from the full training annotations (no few-shot fine-tuning)."""
    dataset_file = args.dataset_file

    if dataset_file in ('coco', 'coco_base'):
        root = Path('data/coco/')
        img_folder = root / "train2017"
        ann_file = root / "annotations" / "instances_train2017.json"
        if dataset_file == 'coco':
            class_ids = coco_base_class_ids + coco_novel_class_ids
        else:
            class_ids = coco_base_class_ids
        return _make_support_dataset(img_folder, ann_file, class_ids, args)

    # Class-id globals are looked up only for the branch actually taken.
    if dataset_file == 'voc':
        class_ids = list(range(1, 20 + 1))
    elif dataset_file == 'voc_base1':
        class_ids = voc_base1_class_ids
    elif dataset_file == 'voc_base2':
        class_ids = voc_base2_class_ids
    elif dataset_file == 'voc_base3':
        class_ids = voc_base3_class_ids
    else:
        raise ValueError(f'unsupported dataset_file {dataset_file!r} for the training support dataset')

    root = Path('data/voc')
    # All VOC variants draw support images from the 2007 and 2012 train/val annotation files.
    ann_files = [root / "annotations" / name for name in (
        'pascal_train2007.json',
        'pascal_val2007.json',
        'pascal_train2012.json',
        'pascal_val2012.json',
    )]
    return _make_support_dataset(root / "images", ann_files, class_ids, args)


def _build_fewshot_support(args):
    """Support dataset built from the few-shot annotation file for the configured seed and shot count."""
    dataset_file = args.dataset_file

    if dataset_file == 'coco_base':
        root = Path('data/coco_fewshot')
        img_folder = root.parent / 'coco' / "train2017"
        class_ids = sorted(coco_base_class_ids + coco_novel_class_ids)
    elif dataset_file in ('voc_base1', 'voc_base2', 'voc_base3'):
        # 'voc_baseK' uses the few-shot annotations of split K.
        root = Path(f'data/voc_fewshot_split{dataset_file[-1]}')
        img_folder = root.parent / 'voc' / "images"
        class_ids = list(range(1, 20 + 1))
    else:
        raise ValueError(f'unsupported dataset_file {dataset_file!r} for the few-shot support dataset')

    ann_file = root / f'seed{args.fewshot_seed}' / f'{args.num_shots}shot.json'
    return _make_support_dataset(img_folder, str(ann_file), class_ids, args)


def build_support_dataset(image_set, args):
    """Build the SupportDataset for ``args.dataset_file``.

    Without few-shot fine-tuning (``args.fewshot_finetune`` falsy), support images
    come from the full training annotations and ``image_set`` must be ``"train"``.
    After few-shot fine-tuning, the support set used for fine-tuning is reused as
    the support dataset for inference (to generate category codes); then
    ``image_set`` must be ``"fewshot"``.

    Args:
        image_set: ``"train"`` or ``"fewshot"``, depending on the mode above.
        args: parsed options; ``fewshot_finetune``, ``dataset_file`` and
            ``cache_mode`` are read, plus ``fewshot_seed`` and ``num_shots`` in
            few-shot mode.

    Returns:
        A ``SupportDataset``.

    Raises:
        ValueError: if ``image_set`` does not match the mode, or the dataset is
            not supported in that mode.
    """
    # Explicit checks rather than ``assert`` so validation still runs under ``python -O``.
    if not args.fewshot_finetune:
        if image_set != "train":
            raise ValueError(f'image_set must be "train" without few-shot fine-tuning, got {image_set!r}')
        return _build_train_support(args)

    if image_set != "fewshot":
        raise ValueError(f'image_set must be "fewshot" after few-shot fine-tuning, got {image_set!r}')
    return _build_fewshot_support(args)
Exemplo n.º 5
0
def build(args, image_set, activated_class_ids, with_support=True):
    """Build the few-shot fine-tuning ``DetectionDataset`` for ``args.dataset_file``.

    Annotations are read from ``<root>/seed<fewshot_seed>/<num_shots>shot.json``,
    where ``<root>`` is ``data/coco_fewshot`` for ``coco_base`` and
    ``data/voc_fewshot_split<K>`` for ``voc_base<K>``.

    Args:
        args: parsed options; ``dataset_file``, ``fewshot_seed``, ``num_shots`` and
            ``cache_mode`` are read, and ``args`` is forwarded to ``DetectionDataset``.
        image_set: must be ``"fewshot"``.
        activated_class_ids: iterable of category ids. A sorted copy is passed on,
            and the caller's object is left unmodified.
        with_support: forwarded to ``DetectionDataset``.

    Returns:
        A ``DetectionDataset``.

    Raises:
        ValueError: if ``image_set`` is not ``"fewshot"`` or the dataset is unsupported.
    """
    # Explicit check rather than ``assert`` so validation still runs under ``python -O``.
    if image_set != "fewshot":
        raise ValueError(f'image_set must be "fewshot", got {image_set!r}')

    dataset_file = args.dataset_file
    if dataset_file == 'coco_base':
        root = Path('data/coco_fewshot')
        img_folder = root.parent / 'coco' / "train2017"
    elif dataset_file in ('voc_base1', 'voc_base2', 'voc_base3'):
        # 'voc_baseK' uses the few-shot annotations of split K.
        root = Path(f'data/voc_fewshot_split{dataset_file[-1]}')
        img_folder = root.parent / 'voc' / "images"
    else:
        raise ValueError(f'unsupported dataset_file {dataset_file!r} for the few-shot dataset')

    ann_file = root / f'seed{args.fewshot_seed}' / f'{args.num_shots}shot.json'
    return DetectionDataset(args,
                            img_folder,
                            str(ann_file),
                            transforms=make_transforms(),
                            support_transforms=make_support_transforms(),
                            return_masks=False,
                            # Sorted copy instead of list.sort(): do not reorder the caller's list.
                            activated_class_ids=sorted(activated_class_ids),
                            with_support=with_support,
                            cache_mode=args.cache_mode,
                            local_rank=get_local_rank(),
                            local_size=get_local_size())