Example #1
import os

import numpy as np

# printProgressBar, saveImage and pil_loader are assumed to be provided
# by the surrounding project's utility modules.


def celebaSetup(inputPath,
                outputPath,
                pathConfig="config_celeba_cropped.json"):
    # pathConfig is accepted by the signature but unused in this excerpt

    imgList = [f for f in os.listdir(
        inputPath) if os.path.splitext(f)[1] == ".jpg"]
    # Approximate center of the face in the aligned 178 x 218 CelebA images
    cx = 89
    cy = 121

    nImgs = len(imgList)

    if not os.path.isdir(outputPath):
        os.mkdir(outputPath)

    for index, item in enumerate(imgList):
        printProgressBar(index, nImgs)
        path = os.path.join(inputPath, item)
        img = np.array(pil_loader(path))

        # Crop a 128 x 128 window centered on the face
        img = img[cy - 64: cy + 64, cx - 64: cx + 64]

        path = os.path.join(outputPath, item)
        saveImage(path, img)

    printProgressBar(nImgs, nImgs)
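
A minimal usage sketch; both directory names are assumptions for illustration (img_align_celeba is the usual name of the aligned CelebA folder):

# Hypothetical paths, for illustration only
celebaSetup("img_align_celeba", "celeba_cropped_128")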
Example #2
import os
import pickle

import numpy as np

# printProgressBar and saveImage are assumed to be provided by the
# surrounding project's utility modules.


def unpackCIFAR10(pathDB, pathOut):
    toLoad = [
        'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
        'data_batch_5'
    ]

    # Check validity: all five training batches must be present
    for item in toLoad:
        filePath = os.path.join(pathDB, item)
        if not os.path.isfile(filePath):
            raise FileNotFoundError("Can't find " + filePath)

    if not os.path.isdir(pathOut):
        os.mkdir(pathOut)

    pathLabels = os.path.join(pathDB, 'batches.meta')
    with open(pathLabels, 'rb') as file:
        labels = pickle.load(file)['label_names']

    for label in labels:
        pathOutLabel = os.path.join(pathOut, label)
        if not os.path.isdir(pathOutLabel):
            os.mkdir(pathOutLabel)

    nImagesPerBatch = 10000
    nImages = nImagesPerBatch * len(toLoad)
    status = 0

    print("Unpacking CIFAR-10...")
    for item in toLoad:
        pathItem = os.path.join(pathDB, item)
        with open(pathItem, 'rb') as file:
            batch = pickle.load(file, encoding='bytes')
        data = batch[b'data']
        dataLabel = batch[b'labels']
        dataNames = batch[b'filenames']

        assert (len(dataLabel) == nImagesPerBatch)
        assert (data.shape[1] == 3072)

        for i in range(nImagesPerBatch):

            # Each CIFAR-10 row stores 1024 red, then 1024 green, then
            # 1024 blue values (one 32 x 32 plane per channel)
            rgbArray = np.zeros((32, 32, 3), 'uint8')
            rgbArray[:, :, 0] = data[i, :1024].reshape(32, 32)
            rgbArray[:, :, 1] = data[i, 1024:2048].reshape(32, 32)
            rgbArray[:, :, 2] = data[i, 2048:].reshape(32, 32)

            name = dataNames[i].decode("utf-8")
            label = labels[dataLabel[i]]

            path = os.path.join(pathOut, label, name)
            saveImage(path, rgbArray)

            printProgressBar(status, nImages)
            status += 1

    printProgressBar(nImages, nImages)
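
A minimal usage sketch, assuming the CIFAR-10 python-version archive has been extracted to cifar-10-batches-py (the standard directory name); the output directory is an assumption:

# Hypothetical paths, for illustration only
unpackCIFAR10("cifar-10-batches-py", "cifar10_images")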
Example #3
import os

# printProgressBar, saveImage, pil_loader and NumpyResize are assumed to
# be provided by the surrounding project's utility modules.


def resizeDataset(inputPath, outputPath, maxSize):

    # Output resolutions and the scale keys used in the dataset profile
    sizes = [64, 128, 512, 1024]
    scales = [0, 5, 6, 8]

    imgList = [
        f for f in os.listdir(inputPath)
        if os.path.splitext(f)[1] in [".jpg", ".npy"]
    ]

    nImgs = len(imgList)

    if maxSize < sizes[0]:
        raise ValueError("Maximum resolution too low")

    if not os.path.isdir(outputPath):
        os.mkdir(outputPath)

    datasetProfile = {}

    for index, size in enumerate(sizes):

        if size > maxSize:
            break

        localPath = os.path.join(outputPath, str(size))
        if not os.path.isdir(localPath):
            os.mkdir(localPath)

        datasetProfile[str(scales[index])] = localPath

        print("Resolution %d x %d" % (size, size))

        resizeModule = NumpyResize((size, size))

        for imgIndex, item in enumerate(imgList):
            printProgressBar(imgIndex, nImgs)
            path = os.path.join(inputPath, item)
            img = pil_loader(path)

            img = resizeModule(img)
            path = os.path.splitext(os.path.join(localPath, item))[0] + ".jpg"
            saveImage(path, img)
        printProgressBar(nImgs, nImgs)

    return datasetProfile, localPath
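
A minimal usage sketch; with maxSize=128 only the 64 and 128 resolutions are built. Paths are assumptions for illustration:

# Hypothetical paths, for illustration only
profile, lastPath = resizeDataset("celeba_cropped_128", "celeba_multiscale", 128)
print(profile)  # e.g. {"0": "celeba_multiscale/64", "5": "celeba_multiscale/128"}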
Example #4
import json
import os

import h5py
import numpy as np

# printProgressBar is assumed to be provided by the surrounding project.


def fashionGenSetup(fashionGenPath,
                    outputPath):

    basePath = os.path.splitext(outputPath)[0]

    if not os.path.isdir(basePath):
        os.mkdir(basePath)

    outputPath = os.path.join(basePath, os.path.basename(basePath))

    h5file = h5py.File(fashionGenPath, 'r')
    imgKey = 'input_image'
    validClasses = ["input_gender", "input_category", "input_pose"]
    nImgs = h5file[imgKey].shape[0]

    outIndexes = {}
    statsPartition = {"GLOBAL": {"input_department": {},
                                 "totalSize": 0}
                      }

    for attribute in validClasses:
        statsPartition["GLOBAL"][attribute] = {}

    partitionCategory = "input_department"

    print("Building the partition..")

    for index in range(nImgs):

        # h5py returns byte strings; strip the b'...' wrapper
        rawVal = str(h5file[partitionCategory][index][0])
        val = rawVal.replace("b'", "").replace("'", "")
        strVal = val  # already a str; kept under both names used below

        # Hand-made fix for the clothing dataset: some pose attributes
        # correspond only to mislabelled data
        if strVal == "CLOTHING" \
                and str(h5file["input_pose"][index][0]) in \
                ["b'id_gridfs_6'", "b'id_gridfs_5'"]:
            continue

        if strVal not in statsPartition:
            statsPartition[strVal] = {attribute: {}
                                      for attribute in validClasses}
            statsPartition[strVal]["totalSize"] = 0
            outIndexes[val] = []
            statsPartition["GLOBAL"]["input_department"][rawVal] = 0

        outIndexes[val].append(index)
        statsPartition[strVal]["totalSize"] += 1
        statsPartition["GLOBAL"]["input_department"][rawVal] += 1
        statsPartition["GLOBAL"]["totalSize"] += 1

        for attribute in validClasses:

            label = str(h5file[attribute][index][0])

            if label not in statsPartition[strVal][attribute]:
                statsPartition[strVal][attribute][label] = 0
            if label not in statsPartition["GLOBAL"][attribute]:
                statsPartition["GLOBAL"][attribute][label] = 0

            statsPartition[strVal][attribute][label] += 1
            statsPartition["GLOBAL"][attribute][label] += 1

        printProgressBar(index, nImgs)
    printProgressBar(nImgs, nImgs)

    h5file.close()

    pathPartition = outputPath + "_partition.h5"
    with h5py.File(pathPartition, 'w') as f:
        for key, value in outIndexes.items():
            f.create_dataset(key, data=np.array(value))

    pathStats = outputPath + "_stats.json"
    with open(pathStats, 'w') as file:
        json.dump(statsPartition, file, indent=2)

    return pathPartition, pathStats
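
A minimal usage sketch; the HDF5 filename follows the public FashionGen release but is an assumption here, as is the output name:

# Hypothetical paths, for illustration only
pathPartition, pathStats = fashionGenSetup("fashiongen_full_size_train.h5",
                                           "fashionGen")
print(pathPartition)  # fashionGen/fashionGen_partition.h5
print(pathStats)      # fashionGen/fashionGen_stats.json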