Example no. 1
    def _download(self):
        _fpath = os.path.join(MyPath.db_root_dir(), self.FILE)

        if os.path.isfile(_fpath):
            print('Files already downloaded')
            return
        else:
            print('Downloading ' + self.URL + ' to ' + _fpath)

            def _progress(count, block_size, total_size):
                sys.stdout.write('\r>> %s %.1f%%' %
                                 (_fpath, float(count * block_size) /
                                  float(total_size) * 100.0))
                sys.stdout.flush()

            urllib.request.urlretrieve(self.URL, _fpath, _progress)

        # extract file
        cwd = os.getcwd()
        print('\nExtracting tar file')
        tar = tarfile.open(_fpath)
        os.chdir(MyPath.db_root_dir())
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('Done!')
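As a side note, the chdir/extract/chdir sequence above can be avoided by handing the destination directly to extractall; a minimal equivalent sketch using the same names (tarfile and MyPath.db_root_dir() as above):

        # Extract directly into the database root without changing the working directory
        with tarfile.open(_fpath) as tar:
            tar.extractall(path=MyPath.db_root_dir())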
Example no. 2
    def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', 
                    transform=None):
        super(ImageNetSubset, self).__init__()

        self.root = os.path.join(root, 'ILSVRC2012_img_%s' %(split))
        self.transform = transform
        self.split = split

        # Read the subset of classes to include (sorted)
        with open(subset_file, 'r') as f:
            result = f.read().splitlines()
        subdirs, class_names = [], []
        for line in result:
            subdir, class_name = line.split(' ', 1)
            subdirs.append(subdir)
            class_names.append(class_name)

        # Gather the files (sorted)
        imgs = []
        for i, subdir in enumerate(subdirs):
            subdir_path = os.path.join(self.root, subdir)
            files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG')))
            for f in files:
                imgs.append((f, i)) 
        self.imgs = imgs 
        self.classes = class_names
    
        # Resize
        self.resize = tf.Resize(256)
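The constructor above only collects (path, label) pairs into self.imgs; a dataset like this also needs __len__ and __getitem__. A minimal sketch of what they could look like given the attributes set above (an illustrative assumption, not the original implementation; it presumes PIL's Image is imported):

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        path, target = self.imgs[index]
        with Image.open(path) as img:              # assumes: from PIL import Image
            img = self.resize(img.convert('RGB'))  # resize transform defined in __init__
        if self.transform is not None:
            img = self.transform(img)
        return img, target
Example no. 3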
    def __init__(self,
                 root=MyPath.db_root_dir('sewer'),
                 split='Training',
                 transform=None):
        super(SewerNet, self).__init__()

        self.root = os.path.join(root, split)
        self.transform = transform

        subdirs = []
        for name in os.listdir(self.root):
            subdirs.append(name)

        print(split)
        print(self.root)
        print(subdirs)

        imgs = []
        for i, subdir in enumerate(subdirs):
            subdir_path = os.path.join(self.root, subdir)
            files = sorted(glob(os.path.join(self.root, subdir, '*.jpg')))
            for f in files:
                imgs.append((f, i))
        self.imgs = imgs
        self.classes = subdirs

        self.resize = tf.Resize(128)  #256
Example no. 4
def get_val_dataset(p,
                    transform=None,
                    to_neighbors_dataset=False,
                    to_similarity_dataset=False,
                    use_negatives=False,
                    use_simpred=False):
    # Base dataset
    if p['val_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split='test', transform=transform, download=True)

    elif p['train_db_name'] in [
            'impact_kb', 'impact_full_balanced', 'impact_full_imbalanced',
            'hdi_balanced', 'hdi_imbalanced', 'tobacco3482', 'rvl-cdip',
            'wpi_demo'
    ]:
        from data.imagefolderwrapper import ImageFolderWrapper
        root = MyPath.db_root_dir(p['train_db_name'])
        dataset = ImageFolderWrapper(root, split="test", transform=transform)

    elif p['val_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='val', transform=transform)

    elif p['val_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['val_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file,
                                 split='val',
                                 transform=transform)

    else:
        raise ValueError('Invalid validation dataset {}'.format(
            p['val_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_neighbors_dataset:  # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        knn_indices = np.load(p['topk_neighbors_val_path'])

        if use_negatives:
            kfn_indices = np.load(p['topk_furthest_val_path'])
        else:
            kfn_indices = None

        dataset = NeighborsDataset(dataset, knn_indices, kfn_indices,
                                   use_simpred, 5, 5)  # Only use 5
    elif to_similarity_dataset:  # Dataset returns an image and another random image.
        from data.custom_dataset import SimilarityDataset
        dataset = SimilarityDataset(dataset)

    return dataset
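A minimal usage sketch, assuming p is the experiment configuration dict; only the key read by the CIFAR-10 branch is set and the value is illustrative:

p = {'val_db_name': 'cifar-10'}
val_dataset = get_val_dataset(p, transform=None)  # downloads the CIFAR-10 test split if missing
print(len(val_dataset))
Example no. 5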
    def __init__(self,
                 root=MyPath.db_root_dir('omniglot'),
                 train=True,
                 transform=None,
                 download=True):

        super(Omniglot, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train  # training set or test set
        if train:
            self.split = 'train'
        else:
            self.split = 'small1'
        ds, info = tfds.load("omniglot",
                             split=self.split,
                             shuffle_files=True,
                             with_info=True)
        ds_np = tfds.as_numpy(ds)
        self.classes = info.features["label"].names

        self.data = []
        self.targets = []
        #i=0

        # now load the picked numpy arrays
        for ex in ds_np:
            #_img = cv2.resize(ex["image"], (32, 32))
            self.data.append(ex["image"])
            self.targets.append(ex["label"])
Example no. 6
    def __init__(self,
                 root=MyPath.db_root_dir('hint'),
                 split='train',
                 transform=None):

        super(HINT, self).__init__()
        self.root = root
        self.img_dir = root + 'symbol_images/'
        self.transform = transform
        self.classes = SYMBOLS

        # split = 'train'
        self.split = split

        dataset = json.load(open(root + 'expr_%s.json' % split))
        # dataset = [x for x in dataset if len(x['expr']) <= 15]
        dataset = [(x, SYM2ID(y)) for sample in dataset
                   for x, y in zip(sample['img_paths'], sample['expr'])]
        label2data = {i: [] for i in range(len(self.classes))}
        for img, label in dataset:
            label2data[label].append((img, label))
        dataset = []
        random.seed(777)
        n_sample_per_class = 5000 if split == 'train' else 500
        for label, data in label2data.items():
            dataset.extend(random.choices(data, k=n_sample_per_class))
        random.shuffle(dataset)
        print(dataset[:10])

        print(sorted(Counter([x[1] for x in dataset]).items()))
        self.dataset = dataset
Example no. 7
    def __init__(self,
                 root=MyPath.db_root_dir('cct20'),
                 train=True,
                 transform=None,
                 download=False):


        super(CCT20, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train  # training set or test set
        self.classes = ['bobcat', 'opossum', 'empty', 'coyote', 'raccoon', 'bird', 'dog', 'cat',
                        'squirrel', 'rabbit', 'skunk', 'rodent', 'badger', 'deer', 'car', 'fox']

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.val_list

        self._load_meta()

        self.img_paths = []
        self.targets = []

        for folder in downloaded_list:
            folder_path = os.path.join(self.root, self.base_folder, folder)
            folder_files = os.listdir(folder_path)
            img_type = '.jpg'
            for file in folder_files:
                if file.endswith(img_type):
                    #img_data = imread(folder_path + '/' + file)
                    img_path = folder_path + '/' + file
                    species = '_'.join(file.split('_')[2:]).replace(img_type, '')
                    self.img_paths.append(img_path)
                    self.targets.append(self.class_to_idx[species])
Example no. 8
    def __init__(self, root=MyPath.db_root_dir('cub'), train=True, transform=None):

        super(CUB, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train
        self.imgNames = []
        self.classes = []

        with open(root + 'classes.txt', 'r') as classesFile:
            for line in classesFile.readlines():
                self.classes.append(line.strip().split(' ')[1].split('.')[1])

        isTrainList = []
        with open(root + 'train_test_split.txt', 'r') as splitFile:
            for line in splitFile.readlines():
                isTrainList.append(int(line.strip().split(' ')[1]))

        classLabels = []
        with open(root + 'image_class_labels.txt', 'r') as labelsFile:
            for line in labelsFile.readlines():
                classLabels.append(int(line.strip().split(' ')[1]) - 1)

        imagePaths = []
        with open(root + 'images.txt', 'r') as imagesFile:
            for line in imagesFile.readlines():
                imagePaths.append(root + 'images/' + line.strip().split(' ')[1])

        self.imagePaths = []
        self.classLabels = []
        for i, isTrain in enumerate(isTrainList):
            if isTrain == train:
                self.imagePaths.append(imagePaths[i])
                self.classLabels.append(classLabels[i])
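Example no. 9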
    def __init__(self,
                 root=MyPath.db_root_dir('imagenet'),
                 split='train',
                 transform=None):
        super(ImageNet, self).__init__(root=os.path.join(root, split),
                                       transform=None)
        self.transform = transform
        self.split = split
        self.resize = tf.Resize(256)
Example no. 10
    def __init__(self,
                 root=MyPath.db_root_dir('batsnet'),
                 split='train',
                 transform=None):
        super(batsnet,
              self).__init__(root=os.path.join(root, 'batsnet_%s' % (split)),
                             transform=None)
        self.transform = transform
        self.split = split
        self.resize = tf.Resize(360)
Example no. 11
    def __init__(self, root=MyPath.db_root_dir('cifar-20'), train=True,
                 transform=None, download=False):
        super(CIFAR20, self).__init__(root, train=train, transform=transform,
                                      download=download)
        # Remap classes from cifar-100 to cifar-20
        new_ = self.targets
        for idx, target in enumerate(self.targets):
            new_[idx] = _cifar100_to_cifar20(target)
        self.targets = new_
        self.classes = [
            'aquatic mammals', 'fish', 'flowers', 'food containers',
            'fruit and vegetables', 'household electrical devices',
            'household furniture', 'insects', 'large carnivores',
            'large man-made outdoor things', 'large natural outdoor scenes',
            'large omnivores and herbivores', 'medium-sized mammals',
            'non-insect invertebrates', 'people', 'reptiles', 'small mammals',
            'trees', 'vehicles 1', 'vehicles 2'
        ]
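_cifar100_to_cifar20 is assumed to map each of the 100 CIFAR-100 fine labels to one of the 20 superclasses listed above. An illustrative, truncated sketch of such a helper (hypothetical; the real table covers all 100 fine labels):

def _cifar100_to_cifar20(target):
    # Hypothetical, truncated lookup: fine label -> coarse (superclass) label
    fine_to_coarse = {0: 4, 1: 1, 2: 14, 3: 8}
    return fine_to_coarse[target]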
Example no. 12
    def _download(self):
        _fpath = os.path.join(MyPath.db_root_dir(), self.FILE)

        if os.path.isfile(_fpath):
            print('Files already downloaded')
            return
        else:
            print('Downloading from google drive')
            mkdir_if_missing(os.path.dirname(_fpath))
            download_file_from_google_drive(self.GOOGLE_DRIVE_ID, _fpath)

        # extract file
        cwd = os.getcwd()
        print('\nExtracting tar file')
        tar = tarfile.open(_fpath)
        os.chdir(MyPath.db_root_dir())
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('Done!')
Example no. 13
    def __init__(self,
                 root=MyPath.db_root_dir('cifar-10'),
                 train=True,
                 transform=None,
                 download=False):

        super(CIFAR10, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train  # training set or test set
        self.classes = [
            'plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
            'ship', 'truck'
        ]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            print(self.root)

            file_path = os.path.join(self.root, self.base_folder, file_name)
            print(file_path)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()
Example no. 14
    def __init__(self,
                 root=MyPath.db_root_dir('bhp_kenya'),
                 train=True,
                 transform=None,
                 download=False):

        super(BHPKenya, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train  # training set or test set

        self.classes = [
            'wildebeest', 'gazelle_thomsons', 'shoats', 'cattle', 'zebra',
            'impala', 'topi', 'warthog', 'hyena_spotted', 'giraffe',
            'elephant', 'hare', 'dikdik', 'hippopotamus', 'jackal',
            'gazelle_grants', 'baboon', 'buffalo', 'eland',
            'mongoose_white_tailed', 'mongoose_banded', 'vervet_monkey',
            'springhare', 'bateared_fox', 'waterbuck', 'hartebeest_cokes',
            'domestic_dog', 'lion', 'aardvark', 'genet', 'serval',
            'mongoose_other', 'porcupine', 'aardwolf', 'oribi', 'other_bird',
            'ostrich', 'bustard_white_bellied', 'guineafowl'
        ]

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.val_list

        self._load_meta()

        self.img_paths = []
        self.targets = []

        for folder in downloaded_list:
            folder_path = os.path.join(self.root, self.base_folder, folder)
            folder_files = os.listdir(folder_path)
            img_type = '.jpg'
            for file in folder_files:
                if file.endswith(img_type):
                    #img_data = imread(folder_path + '/' + file)
                    img_path = folder_path + '/' + file
                    species = '_'.join(file.split('_')[4:]).replace(
                        img_type, '')
                    if species in self.classes:
                        self.img_paths.append(img_path)
                        self.targets.append(self.class_to_idx[species])
Example no. 15
    def __init__(self, root=MyPath.db_root_dir('stl-10'),
                 split='train', folds=None, transform=None,
                 download=False):
        super(STL10, self).__init__()
        self.root = root
        self.transform = transform
        self.split = verify_str_arg(split, "split", self.splits)
        self.folds = self._verify_folds(folds)
        if download:
            self.download()
        elif not self._check_integrity():
            raise RuntimeError(
                'Dataset not found or corrupted. '
                'You can use download=True to download it')

        # now load the picked numpy arrays
        if self.split == 'train':
            self.data, self.labels = self.__loadfile(
                self.train_list[0][0], self.train_list[1][0])
            self.__load_folds(folds)

        elif self.split == 'train+unlabeled':
            self.data, self.labels = self.__loadfile(
                self.train_list[0][0], self.train_list[1][0])
            self.__load_folds(folds)
            unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
            self.data = np.concatenate((self.data, unlabeled_data))
            self.labels = np.concatenate(
                (self.labels, np.asarray([-1] * unlabeled_data.shape[0])))

        elif self.split == 'unlabeled':
            self.data, _ = self.__loadfile(self.train_list[2][0])
            self.labels = np.asarray([-1] * self.data.shape[0])
        else:  # self.split == 'test':
            self.data, self.labels = self.__loadfile(
                self.test_list[0][0], self.test_list[1][0])

        class_file = os.path.join(
            self.root, self.base_folder, self.class_names_file)
        if os.path.isfile(class_file):
            with open(class_file) as f:
                self.classes = f.read().splitlines()

        if self.split == 'train': # Added this to be able to filter out fp from neighbors
            self.targets = self.labels
Example no. 16
    def __init__(self,
                 root=MyPath.db_root_dir('mnist'),
                 train=True,
                 transform=None,
                 download=False):
        super(MNIST, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(
            os.path.join(self.processed_folder, data_file))
Example no. 17
    def __init__(self,
                 root=MyPath.db_root_dir('bird'),
                 split='train',
                 transform=None):
        super(Birds, self).__init__()

        self.split = split

        self.transform = transform

        self.resize = tf.Resize(256)

        path = untar_data(URLs.CUB_200_2011)

        self.files = get_image_files(path / "images")
        self.label = dict(
            sorted(enumerate(set(self.files.map(self.label_func))),
                   key=itemgetter(1)))
        self.labels = dict([(value, key) for key, value in self.label.items()])
        self.df = pd.read_csv(path / 'train_test_split.txt', delimiter=' ')

        if self.split == 'train':
            self.file_index = [
                i['1'] for i in self.df.to_dict('records') if i['0'] == 1
            ]
            self.Files = [
                i for i in self.files if self.splitter(i) in self.file_index
            ]

        else:
            self.file_index = [
                i['1'] for i in self.df.to_dict('records') if i['0'] == 0
            ]
            self.Files = [
                i for i in self.files if self.splitter(i) in self.file_index
            ]
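Example no. 18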
    def __init__(self,
                 root=MyPath.db_root_dir('pascal-voc'),
                 train=True,
                 transform=None):

        super(PASCALVOC, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train
        self.imgNames = []
        self.classNames = []
        self.classes = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
            'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
            'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]
        self.classDict = {}
        for i, name in enumerate(self.classes):
            self.classDict[name] = i  # + 1

        with open(self.root + '/bndBoxImageLabels.txt', 'r') as labelFile:
            for line in labelFile.readlines():
                self.imgNames.append(line.split(' ')[0])
                self.classNames.append(line.split(' ')[1].strip())
Example no. 19
    def __init__(self,
                 root=MyPath.db_root_dir('partnet', ''),
                 split='train',
                 type='chair',
                 transform=None):

        super(PARTNET, self).__init__()
        self.root = MyPath.db_root_dir('partnet', type)
        self.img_dir = self.root + split
        self.transform = transform
        self.images = []
        self.targets = []
        # split = 'train'
        self.split = split
        #self.part_order = {'bag_body': 0, 'handle': 1, 'shoulder_strap': 2}
        #self.classes = ['bag_body', 'handle', 'shoulder_strap']
        #self.part_order = {'headboard': 0, 'bed_sleep_area': 1, 'bed_frame_horizontal_surface':2, 'bed_side_surface_panel': 3, 'bed_post': 4, 'leg': 5, 'surface_base': 6, 'ladder':7}
        #self.classes = ['headboard', 'bed_sleep_area', 'bed_frame_horizontal_surface', 'bed_side_surface_panel', 'bed_post', 'leg', 'surface_base', 'ladder']
        #self.part_order = {'tabletop': 0, 'drawer': 1, 'cabinet_door': 2, 'side_panel': 3, 'bottom_panel': 4, 'central_support': 5, 'leg': 6, 'shelf': 7, 'leg_bar': 8, 'pedestal': 9,\
        #    'chair_head': 10, 'back_surface': 11, 'back_frame_vertical_bar': 12, 'back_frame_horizontal_bar': 13, 'chair_seat': 14, 'chair_arm': 15, 'arm_sofa_style': 16, 'arm_near_vertical_bar': 17, 'arm_horizontal_bar': 18}
        #self.classes = ['tabletop', 'drawer', 'cabinet_door', 'side_panel', 'bottom_panel', 'central_support', 'leg', 'shelf', 'leg_bar', 'pedestal',\
        #'chair_head', 'back_surface', 'back_frame_vertical_bar', 'back_frame_horizontal_bar', 'chair_seat', 'chair_arm', 'arm_sofa_style', 'arm_near_vertical_bar', 'arm_horizontal_bar']
        if type == 'chair':
            self.part_order = {
                'chair_head': 0,
                'back_surface': 1,
                'back_frame_vertical_bar': 2,
                'back_frame_horizontal_bar': 3,
                'chair_seat': 4,
                'chair_arm': 5,
                'arm_sofa_style': 6,
                'arm_near_vertical_bar': 7,
                'arm_horizontal_bar': 8,
                'central_support': 9,
                'leg': 10,
                'leg_bar': 11,
                'pedestal': 12
            }
            self.classes = [
                'chair_head', 'back_surface', 'back_frame_vertical_bar',
                'back_frame_horizontal_bar', 'chair_seat', 'chair_arm',
                'arm_sofa_style', 'arm_near_vertical_bar',
                'arm_horizontal_bar', 'central_support', 'leg', 'leg_bar',
                'pedestal'
            ]
        if type == 'table':
            self.part_order = {
                'tabletop': 0,
                'drawer': 1,
                'cabinet_door': 2,
                'side_panel': 3,
                'bottom_panel': 4,
                'central_support': 5,
                'leg': 6,
                'shelf': 7,
                'leg_bar': 8,
                'pedestal': 9
            }
            self.classes = [
                'tabletop', 'drawer', 'cabinet_door', 'side_panel',
                'bottom_panel', 'central_support', 'leg', 'shelf', 'leg_bar',
                'pedestal'
            ]
        if type == 'bed':
            self.part_order = {
                'headboard': 0,
                'bed_sleep_area': 1,
                'bed_frame_horizontal_surface': 2,
                'bed_side_surface_panel': 3,
                'bed_post': 4,
                'leg': 5,
                'surface_base': 6,
                'ladder': 7
            }
            self.classes = [
                'headboard', 'bed_sleep_area', 'bed_frame_horizontal_surface',
                'bed_side_surface_panel', 'bed_post', 'leg', 'surface_base',
                'ladder'
            ]
        if type == 'bag':
            self.part_order = {'bag_body': 0, 'handle': 1, 'shoulder_strap': 2}
            self.classes = ['bag_body', 'handle', 'shoulder_strap']
        for di in next(os.walk(self.img_dir))[1]:
            if di.endswith("npy"):
                continue
            images = os.listdir(os.path.join(self.img_dir, di))
            for img in images:
                if (not img == "0.png") and img.endswith("occluded.png"):
                    name = img.replace("occluded.png", "")
                    # Read the part-category index for this directory and close the file
                    with open(os.path.join(self.img_dir, di,
                                           "cat2idx.json")) as f:
                        part_dict = json.load(f)
                    for (key, val) in part_dict.items():
                        for value in val:
                            if name == str(value):
                                self.images.append(
                                    os.path.join(self.img_dir, di, img))
                                self.targets.append(self.part_order[key])
Example no. 20
    def __init__(
        self,
        root=MyPath.db_root_dir('PASCAL_MT'),
        download=True,
        split='val',
        transform=None,
        area_thres=0,
        retname=True,
        overfit=False,
        do_edge=True,
        do_human_parts=False,
        do_semseg=False,
        do_normals=False,
        do_sal=False,
        num_human_parts=6,
    ):

        self.root = root
        if download:
            self._download()

        image_dir = os.path.join(self.root, 'JPEGImages')
        self.transform = transform

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.area_thres = area_thres
        self.retname = retname

        # Edge Detection
        self.do_edge = do_edge
        self.edges = []
        edge_gt_dir = os.path.join(self.root, 'pascal-context', 'trainval')

        # Semantic Segmentation
        self.do_semseg = do_semseg
        self.semsegs = []

        # Human Part Segmentation
        self.do_human_parts = do_human_parts
        part_gt_dir = os.path.join(self.root, 'human_parts')
        self.parts = []
        self.human_parts_category = 15
        print(PROJECT_ROOT_DIR)
        self.cat_part = json.load(
            open(
                os.path.join(PROJECT_ROOT_DIR,
                             'data/db_info/pascal_part.json'), 'r'))
        self.cat_part["15"] = self.HUMAN_PART[num_human_parts]
        self.parts_file = os.path.join(
            os.path.join(self.root, 'ImageSets', 'Parts'),
            ''.join(self.split) + '.txt')

        # Surface Normal Estimation
        self.do_normals = do_normals
        _normal_gt_dir = os.path.join(self.root, 'normals_distill')
        self.normals = []
        if self.do_normals:
            with open(
                    os.path.join(PROJECT_ROOT_DIR,
                                 'data/db_info/nyu_classes.json')) as f:
                cls_nyu = json.load(f)
            with open(
                    os.path.join(PROJECT_ROOT_DIR,
                                 'data/db_info/context_classes.json')) as f:
                cls_context = json.load(f)

            self.normals_valid_classes = []
            for cl_nyu in cls_nyu:
                if cl_nyu in cls_context and cl_nyu != 'unknown':
                    self.normals_valid_classes.append(cls_context[cl_nyu])

            # Custom additions due to incompatibilities
            self.normals_valid_classes.append(cls_context['tvmonitor'])

        # Saliency
        self.do_sal = do_sal
        _sal_gt_dir = os.path.join(self.root, 'sal_distill')
        self.sals = []

        # train/val/test splits are pre-cut
        _splits_dir = os.path.join(self.root, 'ImageSets', 'Context')

        self.im_ids = []
        self.images = []

        print("Initializing dataloader for PASCAL {} set".format(''.join(
            self.split)))
        for splt in self.split:
            with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')),
                      "r") as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):
                # Images
                _image = os.path.join(image_dir, line + ".jpg")
                assert os.path.isfile(_image)
                self.images.append(_image)
                self.im_ids.append(line.rstrip('\n'))

                # Edges
                _edge = os.path.join(edge_gt_dir, line + ".mat")
                assert os.path.isfile(_edge)
                self.edges.append(_edge)

                # Semantic Segmentation
                _semseg = self._get_semseg_fname(line)
                assert os.path.isfile(_semseg)
                self.semsegs.append(_semseg)

                # Human Parts
                _human_part = os.path.join(
                    part_gt_dir, line + ".mat"
                )  # issue: self.root and part_gt_dir is the same and will be joined
                assert os.path.isfile(_human_part)
                self.parts.append(_human_part)

                _normal = os.path.join(_normal_gt_dir,
                                       line + ".png")  # self.root,
                assert os.path.isfile(_normal)
                self.normals.append(_normal)

                _sal = os.path.join(_sal_gt_dir, line + ".png")  # self.root,
                assert os.path.isfile(_sal)
                self.sals.append(_sal)

        if self.do_edge:
            assert (len(self.images) == len(self.edges))
        if self.do_human_parts:
            assert (len(self.images) == len(self.parts))
        if self.do_semseg:
            assert (len(self.images) == len(self.semsegs))
        if self.do_normals:
            assert (len(self.images) == len(self.normals))
        if self.do_sal:
            assert (len(self.images) == len(self.sals))

        if not self._check_preprocess_parts():
            print(
                'Pre-processing PASCAL dataset for human parts, this will take long, but will be done only once.'
            )
            self._preprocess_parts()

        if self.do_human_parts:
            # Find images which have human parts
            self.has_human_parts = []
            for ii in range(len(self.im_ids)):
                if self.human_parts_category in self.part_obj_dict[
                        self.im_ids[ii]]:
                    self.has_human_parts.append(1)
                else:
                    self.has_human_parts.append(0)

            # If the other tasks are disabled, select only the images that contain human parts, to allow batching
            if not self.do_edge and not self.do_semseg and not self.do_sal and not self.do_normals:
                print('Ignoring images that do not contain human parts')
                for i in range(len(self.parts) - 1, -1, -1):
                    if self.has_human_parts[i] == 0:
                        del self.im_ids[i]
                        del self.images[i]
                        del self.parts[i]
                        del self.has_human_parts[i]
            print('Number of images with human parts: {:d}'.format(
                np.sum(self.has_human_parts)))

        #  Overfit to n_of images
        if overfit:
            n_of = 64
            self.images = self.images[:n_of]
            self.im_ids = self.im_ids[:n_of]
            if self.do_edge:
                self.edges = self.edges[:n_of]
            if self.do_semseg:
                self.semsegs = self.semsegs[:n_of]
            if self.do_human_parts:
                self.parts = self.parts[:n_of]
            if self.do_normals:
                self.normals = self.normals[:n_of]
            if self.do_sal:
                self.sals = self.sals[:n_of]

        # Display stats
        print('Number of dataset images: {:d}'.format(len(self.images)))
Example no. 21
    def __init__(
        self,
        root=MyPath.db_root_dir('NYUD_MT'),
        download=True,
        split='val',
        transform=None,
        retname=True,
        overfit=False,
        do_edge=False,
        do_semseg=False,
        do_normals=False,
        do_depth=False,
    ):

        self.root = root

        if download:
            self._download()

        self.transform = transform

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.retname = retname

        # Original Images
        self.im_ids = []
        self.images = []
        _image_dir = os.path.join(root, 'images')

        # Edge Detection
        self.do_edge = do_edge
        self.edges = []
        _edge_gt_dir = os.path.join(root, 'edge')

        # Semantic segmentation
        self.do_semseg = do_semseg
        self.semsegs = []
        _semseg_gt_dir = os.path.join(root, 'segmentation')

        # Surface Normals
        self.do_normals = do_normals
        self.normals = []
        _normal_gt_dir = os.path.join(root, 'normals')

        # Depth
        self.do_depth = do_depth
        self.depths = []
        _depth_gt_dir = os.path.join(root, 'depth')

        # train/val/test splits are pre-cut
        _splits_dir = os.path.join(root, 'gt_sets')

        print('Initializing dataloader for NYUD {} set'.format(''.join(
            self.split)))
        for splt in self.split:
            with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')),
                      'r') as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):

                # Images
                _image = os.path.join(_image_dir, line + '.jpg')
                assert os.path.isfile(_image)
                self.images.append(_image)
                self.im_ids.append(line.rstrip('\n'))

                # Edges
                _edge = os.path.join(self.root, _edge_gt_dir, line + '.npy')
                assert os.path.isfile(_edge)
                self.edges.append(_edge)

                # Semantic Segmentation
                _semseg = os.path.join(self.root, _semseg_gt_dir,
                                       line + '.png')
                assert os.path.isfile(_semseg)
                self.semsegs.append(_semseg)

                # Surface Normals
                _normal = os.path.join(self.root, _normal_gt_dir,
                                       line + '.npy')
                assert os.path.isfile(_normal)
                self.normals.append(_normal)

                # Depth Prediction
                _depth = os.path.join(self.root, _depth_gt_dir, line + '.npy')
                assert os.path.isfile(_depth)
                self.depths.append(_depth)

        if self.do_edge:
            assert (len(self.images) == len(self.edges))
        if self.do_semseg:
            assert (len(self.images) == len(self.semsegs))
        if self.do_depth:
            assert (len(self.images) == len(self.depths))
        if self.do_normals:
            assert (len(self.images) == len(self.normals))

        # Uncomment to overfit to one image
        if overfit:
            n_of = 64
            self.images = self.images[:n_of]
            self.im_ids = self.im_ids[:n_of]

        # Display stats
        print('Number of dataset images: {:d}'.format(len(self.images)))
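Example no. 22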
def eval_edge_predictions(p, database, save_dir):
    """ The edge are evaluated through seism """

    print(
        'Evaluate the edge prediction using seism ... This can take a while ...'
    )

    # DataLoaders
    if database == 'PASCALContext':
        from data.pascal_context import PASCALContext
        split = 'val'
        db = PASCALContext(split=split,
                           do_edge=True,
                           do_human_parts=False,
                           do_semseg=False,
                           do_normals=False,
                           do_sal=True,
                           overfit=False)

    else:
        raise NotImplementedError

    # First check if all files are there
    files = glob.glob(os.path.join(save_dir, 'edge/*png'))

    assert (len(files) == len(db))

    # rsync the results to the seism root
    print('Rsync the results to the seism root ...')
    exp_name = database + '_' + p['setup'] + '_' + p['model']
    seism_root = MyPath.seism_root()
    result_dir = os.path.join(seism_root,
                              'datasets/%s/%s/' % (database, exp_name))
    mkdir_if_missing(result_dir)
    os.system('rsync -a %s %s' %
              (os.path.join(save_dir, 'edge/*'), result_dir))
    print('Done ...')

    v = list(np.arange(0.01, 1.00, 0.01))
    parameters_location = os.path.join(seism_root,
                                       'parameters/%s.txt' % (exp_name))
    with open(parameters_location, 'w') as f:
        for l in v:
            f.write('%.2f\n' % (l))

    # generate a seism script that we will run.
    print('Generate seism script to perform the evaluation ...')
    seism_base = os.path.join(PROJECT_ROOT_DIR,
                              'evaluation/seism/pr_curves_base.m')
    with open(seism_base) as f:
        seism_file = f.readlines()
    seism_file = [line.strip() for line in seism_file]
    output_file = [seism_file[0]]

    ## Add experiment parameters (TODO)
    output_file += [
        'addpath(\'%s\')' % (os.path.join(seism_root, 'src/scripts/'))
    ]
    output_file += [
        'addpath(\'%s\')' % (os.path.join(seism_root, 'src/misc/'))
    ]
    output_file += [
        'addpath(\'%s\')' % (os.path.join(seism_root, 'src/tests/'))
    ]
    output_file += [
        'addpath(\'%s\')' % (os.path.join(seism_root, 'src/gt_wrappers/'))
    ]
    output_file += ['addpath(\'%s\')' % (os.path.join(seism_root, 'src/io/'))]
    output_file += [
        'addpath(\'%s\')' % (os.path.join(seism_root, 'src/measures/'))
    ]
    output_file += [
        'addpath(\'%s\')' % (os.path.join(seism_root, 'src/piotr_edges/'))
    ]
    output_file += [
        'addpath(\'%s\')' % (os.path.join(seism_root, 'src/segbench/'))
    ]
    output_file.extend(seism_file[1:18])

    ## Add method (TODO)
    output_file += [
        'methods(end+1).name = \'%s\'; methods(end).io_func = @read_one_png; methods(end).legend =     methods(end).name;  methods(end).type = \'contour\';'
        % (exp_name)
    ]
    output_file.extend(seism_file[19:61])

    ## Add path to save output
    output_file += [
        'filename = \'%s\'' %
        (os.path.join(save_dir, database + '_' + 'test' + '_edge.txt'))
    ]
    output_file += seism_file[62:]

    # save the file to the seism dir
    output_file_path = os.path.join(seism_root, exp_name + '.m')
    with open(output_file_path, 'w') as f:
        for line in output_file:
            f.write(line + '\n')

    # go to the seism dir and perform evaluation
    print(
        'Go to seism root dir and run the evaluation ... This takes time ...')
    cwd = os.getcwd()
    os.chdir(seism_root)
    os.system(
        "matlab -nodisplay -nosplash -nodesktop -r \"addpath(\'%s\');%s;exit\""
        % (seism_root, exp_name))
    os.chdir(cwd)

    # write to json
    print('Finished evaluation in seism ... Write results to JSON ...')
    with open(os.path.join(save_dir, database + '_' + 'test' + '_edge.txt'),
              'r') as f:
        seism_result = [line.strip() for line in f.readlines()]

    eval_dict = {}
    for line in seism_result:
        metric, score = line.split(':')
        eval_dict[metric] = float(score)

    with open(os.path.join(save_dir, database + '_' + 'test' + '_edge.json'),
              'w') as f:
        json.dump(eval_dict, f)

    # print
    print('Edge Detection Evaluation')
    for k, v in eval_dict.items():
        spaces = ''
        for j in range(0, 10 - len(k)):
            spaces += ' '
        print('{0:s}{1:s}{2:.4f}'.format(k, spaces, 100 * v))

    # cleanup - Important. Else Matlab will reuse the files.
    print('Cleanup files in seism ...')
    result_rm = os.path.join(seism_root,
                             'results/%s/%s/' % (database, exp_name))
    data_rm = os.path.join(seism_root,
                           'datasets/%s/%s/' % (database, exp_name))
    os.system("rm -rf %s" % (result_rm))
    os.system("rm -rf %s" % (data_rm))
    print('Finished cleanup ...')

    return eval_dict