Code Example #1
File: div2k.py  Project: huawei-noah/vega
    def dataset_init(self):
        """Costruct method, which will load some dataset information."""
        self.args.root_HR = FileOps.download_dataset(self.args.root_HR)
        self.args.root_LR = FileOps.download_dataset(self.args.root_LR)
        if self.args.subfile is not None:
            # lmdb format has no self.args.subfile
            with open(self.args.subfile) as f:
                file_names = sorted([line.rstrip('\n') for line in f])
                self.datatype = util.get_files_datatype(file_names)
                self.paths_HR = [
                    os.path.join(self.args.root_HR, file_name)
                    for file_name in file_names
                ]
                self.paths_LR = [
                    os.path.join(self.args.root_LR, file_name)
                    for file_name in file_names
                ]
        else:
            self.datatype = util.get_datatype(self.args.root_LR)
            self.paths_LR = util.get_paths_from_dir(self.args.root_LR)
            self.paths_HR = util.get_paths_from_dir(self.args.root_HR)

        if self.args.save_in_memory:
            self.imgs_LR = [self._read_img(path) for path in self.paths_LR]
            self.imgs_HR = [self._read_img(path) for path in self.paths_HR]
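This snippet builds matched HR/LR path lists either from a subset file (one file name per line, present under both roots) or by listing the two directories. Below is a minimal standalone sketch of that pairing logic using only the standard library; build_paired_paths is a hypothetical name, not part of vega.

import os

def build_paired_paths(root_hr, root_lr, subfile=None):
    """Pair HR/LR image paths the way the snippet above does (hypothetical helper)."""
    if subfile is not None:
        with open(subfile) as f:
            names = sorted(line.rstrip("\n") for line in f)
        paths_hr = [os.path.join(root_hr, name) for name in names]
        paths_lr = [os.path.join(root_lr, name) for name in names]
    else:
        paths_hr = sorted(os.path.join(root_hr, name) for name in os.listdir(root_hr))
        paths_lr = sorted(os.path.join(root_lr, name) for name in os.listdir(root_lr))
    return paths_hr, paths_lr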
Code Example #2
    def dataset_init(self):
        """Construct method.

        If both data_dir and label_dir are provided, then use data_dir and label_dir.
        Otherwise, use data_path and list_file.
        """
        if "data_dir" in self.args and "label_dir" in self.args:
            self.args.data_dir = FileOps.download_dataset(self.args.data_dir)
            self.args.label_dir = FileOps.download_dataset(self.args.label_dir)
            self.data_files = sorted(glob.glob(osp.join(self.args.data_dir, "*")))
            self.label_files = sorted(glob.glob(osp.join(self.args.label_dir, "*")))
        else:
            if "data_path" not in self.args or "list_file" not in self.args:
                raise Exception("You must provide a data_path and a list_file!")
            self.args.data_path = FileOps.download_dataset(self.args.data_path)
            with open(osp.join(self.args.data_path, self.args.list_file)) as f:
                lines = f.readlines()
            self.data_files = [None] * len(lines)
            self.label_files = [None] * len(lines)
            for i, line in enumerate(lines):
                data_file_name, label_file_name = line.strip().split()
                self.data_files[i] = osp.join(self.args.data_path, data_file_name)
                self.label_files[i] = osp.join(self.args.data_path, label_file_name)

        datatype = self._get_datatype()
        if datatype == "image":
            self.read_fn = self._read_item_image
        else:
            self.read_fn = self._read_item_pickle
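In the list-file branch, each line of list_file is expected to hold a data file name and its label file name separated by whitespace, both relative to data_path. A small illustration of that per-line split (the file names are made up):

# Hypothetical list_file contents, one "<data_file> <label_file>" pair per line:
#     images/0001.png labels/0001.png
#     images/0002.png labels/0002.png
line = "images/0001.png labels/0001.png\n"
data_file_name, label_file_name = line.strip().split()
assert data_file_name == "images/0001.png"
assert label_file_name == "labels/0001.png"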
Code Example #3
    def __init__(self, **kwargs):
        """Construct the dataset."""
        super().__init__(**kwargs)
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        dataset_pairs = dict(train=create_train_subset(self.args.data_path),
                             test=create_test_subset(self.args.data_path),
                             val=create_test_subset(self.args.data_path))

        if self.mode not in dataset_pairs.keys():
            raise NotImplementedError(
                f'mode should be one of {dataset_pairs.keys()}')
        self.image_annot_path_pairs = dataset_pairs.get(self.mode)

        self.codec_obj = PointLaneCodec(input_width=512,
                                        input_height=288,
                                        anchor_stride=16,
                                        points_per_line=72,
                                        class_num=2)
        self.encode_lane = self.codec_obj.encode_lane
        read_funcs = dict(
            CULane=_read_culane_type_annot,
            CurveLane=_read_curvelane_type_annot,
        )
        if self.args.dataset_format not in read_funcs:
            raise NotImplementedError(
                f'dataset_format should be one of {read_funcs.keys()}')
        self.read_annot = read_funcs.get(self.args.dataset_format)
        self.with_aug = self.args.get('with_aug', False)
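The annotation reader is selected from a plain dict keyed by dataset_format, so supporting a new format only means registering another function. A minimal sketch of the same dispatch pattern with made-up reader functions:

def read_culane(annot_path):
    return "CULane annotations from " + annot_path      # placeholder body

def read_curvelane(annot_path):
    return "CurveLane annotations from " + annot_path   # placeholder body

read_funcs = dict(CULane=read_culane, CurveLane=read_curvelane)

dataset_format = "CULane"
if dataset_format not in read_funcs:
    raise NotImplementedError(f"dataset_format should be one of {list(read_funcs)}")
read_annot = read_funcs[dataset_format]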
Code Example #4
    def __init__(self, **kwargs):
        """Construct the Cifar10 class."""
        Dataset.__init__(self, **kwargs)
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        is_train = (self.mode == 'train'
                    or (self.mode == 'val' and self.args.train_portion < 1))
        self.base_folder = 'cifar-100-python'
        if is_train:
            files_list = ["train"]
        else:
            files_list = ['test']

        self.data = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name in files_list:
            file_path = os.path.join(self.args.data_path, self.base_folder,
                                     file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
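The loop above follows the standard CIFAR python pickle format: each split file is a dict whose 'data' entry is an N x 3072 uint8 array and whose labels live under 'labels' (CIFAR-10) or 'fine_labels' (CIFAR-100). A self-contained sketch of the same loading and reshaping, assuming an extracted 'cifar-100-python' directory:

import os
import pickle

import numpy as np

def load_cifar100_split(data_path, train=True):
    """Load one CIFAR-100 split as (N, 32, 32, 3) uint8 images plus labels."""
    file_path = os.path.join(data_path, "cifar-100-python", "train" if train else "test")
    with open(file_path, "rb") as f:
        entry = pickle.load(f, encoding="latin1")
    data = np.asarray(entry["data"], dtype=np.uint8)
    data = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)  # NCHW -> NHWC
    labels = entry["labels"] if "labels" in entry else entry["fine_labels"]
    return data, labels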
Code Example #5
File: mnist.py  Project: huawei-noah/vega
 def __init__(self, **kwargs):
     """Construct the Mnist class."""
     Dataset.__init__(self, **kwargs)
     self.args.data_path = FileOps.download_dataset(self.args.data_path)
     MNIST.__init__(self,
                    root=self.args.data_path,
                    train=self.train,
                    transform=self.transforms,
                    download=self.args.download)
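Because the class simply delegates to torchvision's MNIST dataset, the equivalent stand-alone call looks like this; the root path and normalization values are illustrative, not taken from vega:

from torchvision import transforms
from torchvision.datasets import MNIST

train_set = MNIST(root="./data",
                  train=True,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,), (0.3081,)),
                  ]),
                  download=True)
image, label = train_set[0]  # a 1x28x28 tensor and its digit label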
Code Example #6
File: div2k_unpair.py  Project: huawei-noah/vega
    def dataset_init(self):
        """Initialize dataset."""
        self.args.HR_dir = FileOps.download_dataset(self.args.HR_dir)
        self.args.LR_dir = FileOps.download_dataset(self.args.LR_dir)
        self.Y_paths = sorted(self.make_dataset(
            self.args.LR_dir,
            float("inf"))) if self.args.LR_dir is not None else None
        self.HR_paths = sorted(
            self.make_dataset(
                self.args.HR_dir,
                float("inf"))) if self.args.HR_dir is not None else None

        self.trans_norm = transforms.Compose(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # Split at image 0401: HR keeps the files before it, LR keeps it and the rest.
        for i in range(len(self.HR_paths)):
            file_name = os.path.basename(self.HR_paths[i])
            if "0401" in file_name:
                logging.info(
                    "Found image NO. 0401 at HR path index {}".format(i))
                self.HR_paths = self.HR_paths[:i]
                break
        for i in range(len(self.Y_paths)):
            file_name = os.path.basename(self.Y_paths[i])
            if "0401" in file_name:
                logging.info(
                    "Found image NO. 0401 at LR path index {}".format(i))
                self.Y_paths = self.Y_paths[i:]
                break

        self.Y_size = len(self.Y_paths)
        if self.train:
            self.load_size = self.args.load_size
            self.crop_size = self.args.crop_size
            self.upscale = self.args.upscale
            self.augment_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip()
            ])
            self.HR_transform = transforms.RandomCrop(
                int(self.crop_size * self.upscale))
            self.LR_transform = transforms.RandomCrop(self.crop_size)
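For training, the HR random crop is upscale times larger than the LR crop, so the two patch sizes stay consistent with the super-resolution factor. A short illustration with made-up values:

from torchvision import transforms

crop_size, upscale = 48, 4                                  # illustrative values
hr_crop = transforms.RandomCrop(int(crop_size * upscale))   # 192 x 192 HR patches
lr_crop = transforms.RandomCrop(crop_size)                  # 48 x 48 LR patches
augment = transforms.Compose([transforms.RandomHorizontalFlip(),
                              transforms.RandomVerticalFlip()])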
Code Example #7
 def __init__(self, **kwargs):
     """Construct the Imagenet class."""
     Dataset.__init__(self, **kwargs)
     self.args.data_path = FileOps.download_dataset(self.args.data_path)
     split = 'train' if self.mode == 'train' else 'val'
     local_data_path = FileOps.join_path(self.args.data_path, split)
     delattr(self, 'loader')
     ImageFolder.__init__(self,
                          root=local_data_path,
                          transform=Compose(self.transforms.__transform__))
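Since the class inherits from torchvision's ImageFolder, the same behaviour can be reproduced directly; the root path and transform pipeline below are illustrative:

from torchvision.datasets import ImageFolder
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor

train_set = ImageFolder(root="/data/imagenet/train",
                        transform=Compose([Resize(256), CenterCrop(224), ToTensor()]))
image, class_index = train_set[0]  # classes are inferred from the sub-directory names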
Code Example #8
File: imagenet.py  Project: huawei-noah/vega
 def __init__(self, **kwargs):
     """Init Cifar10."""
     super(Imagenet, self).__init__(**kwargs)
     self.data_path = FileOps.download_dataset(self.args.data_path)
     self.fp16 = self.args.fp16
     self.num_parallel_batches = self.args.num_parallel_batches
     self.image_size = self.args.image_size
     self.drop_remainder = self.args.drop_last
     if self.data_path == 'null' or not self.data_path:
         self.data_path = None
     self.num_parallel_calls = self.args.num_parallel_calls
Code Example #9
File: cls_ds.py  Project: huawei-noah/vega
 def __init__(self, **kwargs):
     """Construct the classification class."""
     Dataset.__init__(self, **kwargs)
     self.args.data_path = FileOps.download_dataset(self.args.data_path)
     sub_path = os.path.abspath(os.path.join(self.args.data_path,
                                             self.mode))
     if self.args.train_portion != 1.0 and self.mode == "val":
         sub_path = os.path.abspath(
             os.path.join(self.args.data_path, "train"))
     if self.args.train_portion == 1.0 and self.mode == "val" and not os.path.exists(
             sub_path):
         sub_path = os.path.abspath(
             os.path.join(self.args.data_path, "test"))
     if not os.path.exists(sub_path):
          raise FileNotFoundError(
              "dataset path does not exist, path={}".format(sub_path))
     self._load_file_indexes(sub_path)
     self._load_data()
     self._shuffle()
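The branching above resolves which sub-directory backs a given mode: 'val' falls back to 'train' when the data is partitioned by train_portion, and to 'test' when no 'val' directory exists. A standalone sketch of that resolution (resolve_split_dir is a hypothetical name):

import os

def resolve_split_dir(data_path, mode, train_portion=1.0):
    """Return the directory that should back `mode`, mirroring the fallbacks above."""
    sub_path = os.path.abspath(os.path.join(data_path, mode))
    if mode == "val" and train_portion != 1.0:
        sub_path = os.path.abspath(os.path.join(data_path, "train"))
    elif mode == "val" and not os.path.exists(sub_path):
        sub_path = os.path.abspath(os.path.join(data_path, "test"))
    if not os.path.exists(sub_path):
        raise FileNotFoundError("dataset path does not exist, path={}".format(sub_path))
    return sub_path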
Code Example #10
File: avazu.py  Project: huawei-noah/vega
 def __init__(self, **kwargs):
     """Construct the AvazuDataset class."""
     super(AvazuDataset, self).__init__(**kwargs)
     self.args.data_path = FileOps.download_dataset(self.args.data_path)
     logging.info("init new avazu_dataset finish. 0721 debug.")