예제 #1
0
    def _get_dataloaders(self, num_workers, shuffle_train=True):
        """Create the training and validation DataLoaders.

        Training images are read from ``config_dl.train_imgs_glob``. They are
        randomly cropped to ``config_dl.crop_size``, randomly flipped
        horizontally and converted with
        ``IndexImagesDataset.to_tensor_uint8_transform()``.

        Validation images come from ``self._get_ds_val(config_dl.val_glob)``.
        That call uses a crop of the same size and truncates to
        ``num_val_batches * batchsize_val`` images.

        :param num_workers: number of worker processes used by both loaders.
        :param shuffle_train: whether the training loader shuffles its data.
        :return: tuple ``(dl_train, dl_val)``.
        :raises AssertionError: if ``config_dl.train_imgs_glob`` is None.
        """
        cfg = self.config_dl
        assert cfg.train_imgs_glob is not None
        print('Cropping to {}'.format(cfg.crop_size))
        to_tensor_transform = transforms.Compose([
            transforms.RandomCrop(cfg.crop_size),
            transforms.RandomHorizontalFlip(),
            images_loader.IndexImagesDataset.to_tensor_uint8_transform(),
        ])
        # NOTE: if there are images in your training set with dimensions < 128,
        # training will abort at some point because the cropper fails. See the
        # README, section about data preparation.
        # A minimum image size is only requested from ImagesCached for crops
        # larger than 128; smaller crops rely on the note above being respected.
        min_size = None if cfg.crop_size <= 128 else cfg.crop_size
        ds_train = images_loader.IndexImagesDataset(
            images=images_loader.ImagesCached(
                cfg.train_imgs_glob,
                cfg.image_cache_pkl,
                min_size=min_size),
            to_tensor_transform=to_tensor_transform)

        dl_train = DataLoader(ds_train, cfg.batchsize_train, shuffle=shuffle_train,
                              num_workers=num_workers)
        # drop_last is left at DataLoader's default here, so if the final batch
        # is partial this image count is an upper bound.
        print('Created DataLoader [train] {} batches -> {} imgs'.format(
            len(dl_train), cfg.batchsize_train * len(dl_train)))

        ds_val = self._get_ds_val(
            cfg.val_glob,
            crop=cfg.crop_size,
            truncate=cfg.num_val_batches * cfg.batchsize_val)
        dl_val = DataLoader(
            ds_val, cfg.batchsize_val, shuffle=False,
            num_workers=num_workers, drop_last=True)
        # The validation image count must use the validation batch size;
        # with drop_last=True every counted batch is full.
        print('Created DataLoader [val] {} batches -> {} imgs'.format(
            len(dl_val), cfg.batchsize_val * len(dl_val)))

        return dl_train, dl_val
예제 #2
0
    def _get_ds_val(self, images_spec, crop=False, truncate=False):
        """Build the validation dataset for ``images_spec``.

        Images are loaded through ``images_loader.ImagesCached``. That call
        uses ``config_dl.image_cache_pkl`` and ``config_dl.val_glob_min_size``.
        Images are converted with
        ``IndexImagesDataset.to_tensor_uint8_transform()``, optionally
        preceded by a center crop.

        :param images_spec: image specification forwarded to ``ImagesCached``
            (presumably a glob pattern -- confirm against ImagesCached).
        :param crop: if truthy, the size given to ``transforms.CenterCrop``,
            which runs before the tensor conversion.
        :param truncate: if truthy, the dataset is wrapped in
            ``pe.TruncatedDataset`` with this many elements.
        :return: the (possibly truncated) validation dataset.
        """
        uint8_transform = images_loader.IndexImagesDataset.to_tensor_uint8_transform()
        if crop:
            pipeline = [transforms.CenterCrop(crop), uint8_transform]
        else:
            pipeline = [uint8_transform]
        to_tensor = transforms.Compose(pipeline)

        module_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(module_dir, 'fixedimg.jpg')
        if os.path.isfile(candidate):
            fixed_first = candidate
        else:
            print(f'INFO: No file found at {candidate}')
            fixed_first = None

        # A fixed first image keeps results consistent in TensorBoard.
        dataset = images_loader.IndexImagesDataset(
            images=images_loader.ImagesCached(
                images_spec,
                self.config_dl.image_cache_pkl,
                min_size=self.config_dl.val_glob_min_size),
            to_tensor_transform=to_tensor,
            fixed_first=fixed_first)

        if not truncate:
            return dataset
        return pe.TruncatedDataset(dataset, num_elemens=truncate)