예제 #1
0
                                          img_size,
                                          shots=shots,
                                          shuffle=True,
                                          phase=args.phase,
                                          inter=(args.dataset == "inter"))
            metaclass = metadataset.metaclass

        elif args.dataset == 'pascal_voc_0712':
            if args.phase == 1:
                img_set = [('2007', 'trainval'), ('2012', 'trainval')]
            else:
                img_set = [('2007', 'trainval')]
            metadataset = MetaDataset('data/VOCdevkit',
                                      img_set,
                                      metaclass,
                                      img_size,
                                      shots=shots,
                                      shuffle=True,
                                      phase=args.phase)

        elif args.dataset == "object3d":
            metadataset = MetaDataset3D('/home/xiao/Datasets/ObjectNet3D',
                                        'ObjectNet3D_new.txt',
                                        img_size,
                                        'train',
                                        shots=shots,
                                        shuffle=True,
                                        phase=args.phase)
            metaclass = metadataset.metaclass

        elif args.dataset == "custom":
예제 #2
0
    else:
        # Second phase only use fewshot number of base and novel classes
        shots = args.shots
        if args.meta_type == 1:  #  use the first sets of all classes
            metaclass = cfg.TRAIN.ALLCLASSES_FIRST
        if args.meta_type == 2:  #  use the second sets of all classes
            metaclass = cfg.TRAIN.ALLCLASSES_SECOND
        if args.meta_type == 3:  #  use the third sets of all classes
            metaclass = cfg.TRAIN.ALLCLASSES_THIRD
    # prepare meta sets for meta training
    if args.meta_train:
        # construct the input dataset of PRN network
        img_size = 224
        metadataset = MetaDataset('data/VOCdevkit2007', [('2007', 'trainval')],
                                  metaclass,
                                  img_size,
                                  shots=shots,
                                  shuffle=True)

        metaloader = torch.utils.data.DataLoader(metadataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=0,
                                                 pin_memory=True)

    imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdb_name)
    # filter roidb for the second phase
    if args.phase == 2:
        roidb = filter_class_roidb(roidb, args.shots, imdb)
        ratio_list, ratio_index = rank_roidb_ratio(roidb)
        imdb.set_roidb(roidb)