예제 #1
0
def get_dataset(prefetch=False):
    """Fetch the BSDS300 images if missing and build train/test iterators.

    When ``dataset_path/BSDS300/images`` does not exist, the archive at
    ``dataset_url`` is downloaded, unpacked into ``dataset_path`` and the
    archive file is deleted afterwards (also on failure).  If extraction
    fails, the partially extracted image directory is removed so the next
    call retries instead of treating the dataset as present.

    Both iterators use the same transform lists: the input side is
    center-cropped to ``crop_size`` and then passed through
    ``ResizeAug(input_crop_size)``; the target side is only center-cropped.

    Args:
        prefetch: if true, wrap each iterator in ``PrefetchingIter``.

    Returns:
        ``(train_iter, test_iter)`` as a tuple of ``ImagePairIter`` objects
        over ``images/train`` and ``images/test``, or a two-element list of
        ``PrefetchingIter`` wrappers when ``prefetch`` is true.
    """
    image_path = os.path.join(dataset_path, "BSDS300/images")

    if not os.path.exists(image_path):
        import shutil

        # exist_ok: dataset_path may already exist (shared datasets dir or an
        # earlier interrupted run) even though the images are missing.
        os.makedirs(dataset_path, exist_ok=True)
        file_name = download(dataset_url)
        try:
            with tarfile.open(file_name) as tar:
                if hasattr(tarfile, "data_filter"):
                    # Reject absolute paths, ".." members and links that
                    # would escape dataset_path (downloaded, untrusted data).
                    tar.extractall(dataset_path, filter="data")
                else:
                    # NOTE(review): this interpreter has no tarfile extraction
                    # filters; the archive is trusted to hold only safe,
                    # relative member paths.
                    tar.extractall(dataset_path)
        except BaseException:
            # A partially extracted tree would satisfy the existence check
            # above on the next run; remove it so extraction is retried.
            shutil.rmtree(image_path, ignore_errors=True)
            raise
        finally:
            os.remove(file_name)

    crop_size = 256
    # Round down to a multiple of upscale_factor so that
    # input_crop_size * upscale_factor == crop_size exactly.
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    input_transform = [
        CenterCropAug((crop_size, crop_size)),
        ResizeAug(input_crop_size)
    ]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (ImagePairIter(os.path.join(image_path, "train"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size), batch_size, color_flag,
                           input_transform, target_transform),
             ImagePairIter(os.path.join(image_path, "test"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size), test_batch_size, color_flag,
                           input_transform, target_transform))

    return [PrefetchingIter(i) for i in iters] if prefetch else iters
예제 #2
0
def get_dataset(datadir):
    """Build train/validation iterators from the image list in ``datadir``.

    ``datadir/filelist.txt`` is read line by line; each non-blank line,
    stripped of surrounding whitespace, is joined onto ``datadir`` to form an
    image path.  Every 10th listed file (positions 0, 10, 20, ...) goes to the
    validation set, all others to the training set.

    Both iterators receive the same transform lists: ``common_transform``
    (a random ``crop_size`` x ``crop_size`` crop), ``input_transform``
    (``ResizeAug(input_crop_size)``) and an empty ``target_transform``.

    Args:
        datadir: directory containing ``filelist.txt`` and the listed images.

    Returns:
        ``(train_iter, valid_iter)``, a tuple of two ``ImagePairIter``
        objects.
    """
    filelist_path = os.path.join(datadir, "filelist.txt")
    with open(filelist_path) as filelist:
        # Skip blank lines (e.g. a trailing empty line): joined with datadir
        # they would name the directory itself rather than an image file.
        names = [name for name in (line.strip() for line in filelist) if name]

    valid_files = []
    train_files = []
    for index, name in enumerate(names):
        # Deterministic ~10% hold-out: every 10th entry is for validation.
        split = valid_files if index % 10 == 0 else train_files
        split.append(os.path.join(datadir, name))

    crop_size = 256
    # Round down to a multiple of upscale_factor so that
    # input_crop_size * upscale_factor == crop_size exactly (matches the
    # other dataset loaders).
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    common_transform = [
        RandomCropAug((crop_size, crop_size)),
    ]
    input_transform = [ResizeAug(input_crop_size)]
    target_transform = []

    iters = (ImagePairIter(train_files, (input_crop_size, input_crop_size),
                           (crop_size, crop_size), batch_size,
                           common_transform, input_transform,
                           target_transform),
             ImagePairIter(valid_files, (input_crop_size, input_crop_size),
                           (crop_size, crop_size), test_batch_size,
                           common_transform, input_transform,
                           target_transform))

    print("#Train={}, #Valid={}".format(len(train_files), len(valid_files)))

    return iters
예제 #3
0
def get_dataset(prefetch=False):
    """Download the BSDS500 dataset and return train and test iters.

    If ``data_dir`` already exists the download is skipped.  Otherwise the
    zip archive at ``dataset_url`` is downloaded into ``datasets_tmpdir``,
    unpacked into ``tmp_dir``, and its ``images`` and ``groundTruth``
    folders are copied into ``data_dir``.  Both temporary directories are
    removed afterwards, whether or not extraction succeeded.  If anything
    fails after ``data_dir`` is created, ``data_dir`` is removed again so the
    next call retries instead of skipping a half-populated directory.

    Both iterators use the same transform lists: the input side is
    center-cropped to ``crop_size`` and then passed through
    ``ResizeAug(input_crop_size)``; the target side is only center-cropped.

    Args:
        prefetch: if true, wrap each iterator in ``PrefetchingIter``.

    Returns:
        ``(train_iter, test_iter)`` as a tuple of ``ImagePairIter`` objects
        over ``data_dir/images/train`` and ``data_dir/images/test``, or a
        two-element list of ``PrefetchingIter`` wrappers when ``prefetch``
        is true.
    """

    if path.exists(data_dir):
        print(
            "Directory {} already exists, skipping.\n"
            "To force download and extraction, delete the directory and re-run."
            "".format(data_dir),
            file=sys.stderr,
        )
    else:
        print("Downloading dataset...", file=sys.stderr)
        downloaded_file = download(dataset_url, dirname=datasets_tmpdir)
        print("done", file=sys.stderr)

        print("Extracting files...", end="", file=sys.stderr)
        os.makedirs(data_dir)
        try:
            # Start from an empty tmp_dir: a leftover from an earlier crash
            # would otherwise make makedirs fail or mix in stale files.
            shutil.rmtree(tmp_dir, ignore_errors=True)
            os.makedirs(tmp_dir)
            with zipfile.ZipFile(downloaded_file) as archive:
                archive.extractall(tmp_dir)

            source_root = path.join(tmp_dir, "BSDS500-master", "BSDS500",
                                    "data")
            for subdir in ("images", "groundTruth"):
                shutil.copytree(path.join(source_root, subdir),
                                path.join(data_dir, subdir))
        except BaseException:
            # Without this, a half-populated data_dir would send every later
            # call into the "already exists, skipping" branch above.
            shutil.rmtree(data_dir, ignore_errors=True)
            raise
        finally:
            # Best-effort cleanup; must not mask an exception raised above.
            shutil.rmtree(datasets_tmpdir, ignore_errors=True)
            shutil.rmtree(tmp_dir, ignore_errors=True)
        print("done", file=sys.stderr)

    crop_size = 256
    # Round down to a multiple of upscale_factor so that
    # input_crop_size * upscale_factor == crop_size exactly.
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    input_transform = [
        CenterCropAug((crop_size, crop_size)),
        ResizeAug(input_crop_size)
    ]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (
        ImagePairIter(
            path.join(data_dir, "images", "train"),
            (input_crop_size, input_crop_size),
            (crop_size, crop_size),
            batch_size,
            color_flag,
            input_transform,
            target_transform,
        ),
        ImagePairIter(
            path.join(data_dir, "images", "test"),
            (input_crop_size, input_crop_size),
            (crop_size, crop_size),
            test_batch_size,
            color_flag,
            input_transform,
            target_transform,
        ),
    )

    return [PrefetchingIter(i) for i in iters] if prefetch else iters