def dataset_init(self): """Costruct method, which will load some dataset information.""" self.args.root_HR = FileOps.download_dataset(self.args.root_HR) self.args.root_LR = FileOps.download_dataset(self.args.root_LR) if self.args.subfile is not None: with open(self.args.subfile ) as f: # lmdb format has no self.args.subfile file_names = sorted([line.rstrip('\n') for line in f]) self.datatype = util.get_files_datatype(file_names) self.paths_HR = [ os.path.join(self.args.root_HR, file_name) for file_name in file_names ] self.paths_LR = [ os.path.join(self.args.root_LR, file_name) for file_name in file_names ] else: self.datatype = util.get_datatype(self.args.root_LR) self.paths_LR = util.get_paths_from_dir(self.args.root_LR) self.paths_HR = util.get_paths_from_dir(self.args.root_HR) if self.args.save_in_memory: self.imgs_LR = [self._read_img(path) for path in self.paths_LR] self.imgs_HR = [self._read_img(path) for path in self.paths_HR]
def dataset_init(self): """Construct method. If both data_dir and label_dir are provided, then use data_dir and label_dir Otherwise use data_path and list_file. """ if "data_dir" in self.args and "label_dir" in self.args: self.args.data_dir = FileOps.download_dataset(self.args.data_dir) self.args.label_dir = FileOps.download_dataset(self.args.label_dir) self.data_files = sorted(glob.glob(osp.join(self.args.data_dir, "*"))) self.label_files = sorted(glob.glob(osp.join(self.args.label_dir, "*"))) else: if "data_path" not in self.args or "list_file" not in self.args: raise Exception("You must provide a data_path and a list_file!") self.args.data_path = FileOps.download_dataset(self.args.data_path) with open(osp.join(self.args.data_path, self.args.list_file)) as f: lines = f.readlines() self.data_files = [None] * len(lines) self.label_files = [None] * len(lines) for i, line in enumerate(lines): data_file_name, label_file_name = line.strip().split() self.data_files[i] = osp.join(self.args.data_path, data_file_name) self.label_files[i] = osp.join(self.args.data_path, label_file_name) datatype = self._get_datatype() if datatype == "image": self.read_fn = self._read_item_image else: self.read_fn = self._read_item_pickle
def __init__(self, **kwargs): """Construct the dataset.""" super().__init__(**kwargs) self.args.data_path = FileOps.download_dataset(self.args.data_path) dataset_pairs = dict(train=create_train_subset(self.args.data_path), test=create_test_subset(self.args.data_path), val=create_test_subset(self.args.data_path)) if self.mode not in dataset_pairs.keys(): raise NotImplementedError( f'mode should be one of {dataset_pairs.keys()}') self.image_annot_path_pairs = dataset_pairs.get(self.mode) self.codec_obj = PointLaneCodec(input_width=512, input_height=288, anchor_stride=16, points_per_line=72, class_num=2) self.encode_lane = self.codec_obj.encode_lane read_funcs = dict( CULane=_read_culane_type_annot, CurveLane=_read_curvelane_type_annot, ) if self.args.dataset_format not in read_funcs: raise NotImplementedError( f'dataset_format should be one of {read_funcs.keys()}') self.read_annot = read_funcs.get(self.args.dataset_format) self.with_aug = self.args.get('with_aug', False)
def __init__(self, **kwargs): """Construct the Cifar10 class.""" Dataset.__init__(self, **kwargs) self.args.data_path = FileOps.download_dataset(self.args.data_path) is_train = self.mode == 'train' or self.mode == 'val' and self.args.train_portion < 1 self.base_folder = 'cifar-100-python' if is_train: files_list = ["train"] else: files_list = ['test'] self.data = [] self.targets = [] # now load the picked numpy arrays for file_name in files_list: file_path = os.path.join(self.args.data_path, self.base_folder, file_name) with open(file_path, 'rb') as f: entry = pickle.load(f, encoding='latin1') self.data.append(entry['data']) if 'labels' in entry: self.targets.extend(entry['labels']) else: self.targets.extend(entry['fine_labels']) self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
def __init__(self, **kwargs): """Construct the Mnist class.""" Dataset.__init__(self, **kwargs) self.args.data_path = FileOps.download_dataset(self.args.data_path) MNIST.__init__(self, root=self.args.data_path, train=self.train, transform=self.transforms, download=self.args.download)
def dataset_init(self): """Initialize dataset.""" self.args.HR_dir = FileOps.download_dataset(self.args.HR_dir) self.args.LR_dir = FileOps.download_dataset(self.args.LR_dir) self.Y_paths = sorted(self.make_dataset( self.args.LR_dir, float("inf"))) if self.args.LR_dir is not None else None self.HR_paths = sorted( self.make_dataset( self.args.HR_dir, float("inf"))) if self.args.HR_dir is not None else None self.trans_norm = transforms.Compose( [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) for i in range(len(self.HR_paths)): file_name = os.path.basename(self.HR_paths[i]) if (file_name.find("0401") >= 0): logging.info( "We find the possion of NO. 401 in the HR patch NO. {}". format(i)) self.HR_paths = self.HR_paths[:i] break for i in range(len(self.Y_paths)): file_name = os.path.basename(self.Y_paths[i]) if (file_name.find("0401") >= 0): logging.info( "We find the possion of NO. 401 in the LR patch NO. {}". format(i)) self.Y_paths = self.Y_paths[i:] break self.Y_size = len(self.Y_paths) if self.train: self.load_size = self.args.load_size self.crop_size = self.args.crop_size self.upscale = self.args.upscale self.augment_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip() ]) self.HR_transform = transforms.RandomCrop( int(self.crop_size * self.upscale)) self.LR_transform = transforms.RandomCrop(self.crop_size)
def __init__(self, **kwargs): """Construct the Imagenet class.""" Dataset.__init__(self, **kwargs) self.args.data_path = FileOps.download_dataset(self.args.data_path) split = 'train' if self.mode == 'train' else 'val' local_data_path = FileOps.join_path(self.args.data_path, split) delattr(self, 'loader') ImageFolder.__init__(self, root=local_data_path, transform=Compose(self.transforms.__transform__))
def __init__(self, **kwargs): """Init Cifar10.""" super(Imagenet, self).__init__(**kwargs) self.data_path = FileOps.download_dataset(self.args.data_path) self.fp16 = self.args.fp16 self.num_parallel_batches = self.args.num_parallel_batches self.image_size = self.args.image_size self.drop_remainder = self.args.drop_last if self.data_path == 'null' or not self.data_path: self.data_path = None self.num_parallel_calls = self.args.num_parallel_calls
def __init__(self, **kwargs): """Construct the classification class.""" Dataset.__init__(self, **kwargs) self.args.data_path = FileOps.download_dataset(self.args.data_path) sub_path = os.path.abspath(os.path.join(self.args.data_path, self.mode)) if self.args.train_portion != 1.0 and self.mode == "val": sub_path = os.path.abspath( os.path.join(self.args.data_path, "train")) if self.args.train_portion == 1.0 and self.mode == "val" and not os.path.exists( sub_path): sub_path = os.path.abspath( os.path.join(self.args.data_path, "test")) if not os.path.exists(sub_path): raise ("dataset path is not existed, path={}".format(sub_path)) self._load_file_indexes(sub_path) self._load_data() self._shuffle()
def __init__(self, **kwargs): """Construct the AvazuDataset class.""" super(AvazuDataset, self).__init__(**kwargs) self.args.data_path = FileOps.download_dataset(self.args.data_path) logging.info("init new avazu_dataset finish. 0721 debug.")