Exemple #1
0
    def populateFeedGraph(self, shuffle=True, cutPaperOut=False):
        """Build the tf.data pipeline that feeds the network from image files.

        Every entry of ``self.pathList`` is mapped through
        ``self.__readImages``, optionally shuffled, repeated indefinitely and
        batched by ``self.batchSize``. When ``self.inputImageSize`` is larger
        than ``self.tileSize`` the batch is cropped to ``tileSize`` x
        ``tileSize``: centred when ``self.fixCrop`` is set, at a random
        position otherwise.

        Args:
            shuffle: When True, shuffle the examples with a 16-element buffer.
            cutPaperOut: Not used by this method.

        Sets ``self.gammaCorrectedInputsBatch``, ``self.stepsPerEpoch``,
        ``self.inputBatch``, ``self.targetBatch``, ``self.iterator`` and
        ``self.pathBatch``. The iterator is an initializable one: its
        ``initializer`` has to be run before batches are fetched.
        """
        with tf.name_scope("load_images"):
            # Create a tensor out of the list of paths.
            filenamesTensor = tf.constant(self.pathList)
            # One dataset element per entry along the first axis of the tensor.
            dataset = tf.data.Dataset.from_tensor_slices(filenamesTensor)

            # Use a quarter of the cores to read images, but never fewer than
            # one: int(cpu_count() / 4) is 0 on machines with fewer than four
            # cores, which is not a valid num_parallel_calls value.
            nbParallelCalls = max(1, multiprocessing.cpu_count() // 4)
            dataset = dataset.map(self.__readImages,
                                  num_parallel_calls=nbParallelCalls)
            if shuffle:
                dataset = dataset.shuffle(buffer_size=16,
                                          reshuffle_each_iteration=True)
            # Authorize repetition of the dataset when one epoch is over.
            dataset = dataset.repeat()
            # Set the batch size and keep a few batches ready in advance.
            batched_dataset = dataset.batch(self.batchSize)
            batched_dataset = batched_dataset.prefetch(buffer_size=4)
            # Create an iterator to be initialized by the caller.
            iterator = batched_dataset.make_initializable_iterator()

            # Create the node to retrieve the next batch.
            paths_batch, inputs_batch, targets_batch, gammadInputBatch = iterator.get_next(
            )

            self.gammaCorrectedInputsBatch = gammadInputBatch
            # NOTE(review): this result is not used anywhere in this method;
            # the call is kept so the graph that gets built is unchanged.
            reshaped_targets = helpers.target_reshape(targets_batch)
            inputRealSize = self.inputImageSize

            # Crop to tileSize x tileSize: centred if the crop is fixed,
            # random otherwise.
            if inputRealSize > self.tileSize:
                if self.fixCrop:
                    centerOffset = (inputRealSize - self.tileSize) // 2
                    xyCropping = [centerOffset, centerOffset]
                else:
                    # Draw one offset per spatial axis. Drawing a single
                    # value and reusing it for both axes restricts the random
                    # crops to the diagonal of the image.
                    xyCropping = tf.random_uniform([2],
                                                   0,
                                                   inputRealSize -
                                                   self.tileSize,
                                                   dtype=tf.int32)

                # NOTE(review): the inputs are indexed with five indices here
                # but declared rank 4 by set_shape below - confirm the rank
                # returned by __readImages.
                inputs_batch = inputs_batch[:, :, xyCropping[0]:xyCropping[0] +
                                            self.tileSize,
                                            xyCropping[1]:xyCropping[1] +
                                            self.tileSize, :]
                targets_batch = targets_batch[:, :,
                                              xyCropping[0]:xyCropping[0] +
                                              self.tileSize,
                                              xyCropping[1]:xyCropping[1] +
                                              self.tileSize, :]

            # Set the static shapes.
            inputs_batch.set_shape([None, self.tileSize, self.tileSize, 3])
            targets_batch.set_shape(
                [None, self.nbTargetsToRead, self.tileSize, self.tileSize, 3])

            # Populate the object.
            self.stepsPerEpoch = int(
                math.floor(len(self.pathList) / self.batchSize))
            self.inputBatch = inputs_batch
            self.targetBatch = targets_batch
            self.iterator = iterator
            self.pathBatch = paths_batch
Exemple #2
0
    def __renderInputs(self, materials, renderingScene, jitterLightPos,
                       jitterViewPos, mixMaterials):
        """Crop the materials to the tile size and render the network input.

        Args:
            materials: Batch of material maps; it is sliced on five axes and
                then declared as
                ``[None, nbTargetsToRead, tileSize, tileSize, 3]``.
            renderingScene: Forwarded to ``helpers.generateInputRenderings``.
            jitterLightPos: Forwarded to ``helpers.generateInputRenderings``.
            jitterViewPos: Forwarded to ``helpers.generateInputRenderings``.
            mixMaterials: When True, even- and odd-indexed materials of the
                batch are blended pairwise with a random weight in
                [0.1, 0.9).

        Returns:
            A ``(targets, inputs)`` tuple: the preprocessed materials brought
            back to one entry per target map, and the preprocessed rendering.

        Raises:
            Exception: If ``self.inputImageSize`` is smaller than
                ``self.tileSize``.

        Also stores the rendering, before gamma correction, in
        ``self.gammaCorrectedInputsBatch``.
        """
        blendedMaterial = materials
        if mixMaterials:
            blendWeight = tf.random_uniform([1],
                                            minval=0.1,
                                            maxval=0.9,
                                            dtype=tf.float32,
                                            name="mixAlpha")

            evenMaterials = materials[::2]
            oddMaterials = materials[1::2]

            blendedMaterial = helpers.mixMaterials(evenMaterials,
                                                   oddMaterials, blendWeight)

        # A rendering cannot be larger than the maps it is cropped from.
        if self.inputImageSize < self.tileSize:
            raise Exception(
                "Size of the input is inferior to the size of the rendering, please provide higher resolution maps"
            )

        if self.fixCrop:
            # Centred crop.
            centerOffset = (self.inputImageSize - self.tileSize) // 2
            cropOrigin = [centerOffset, centerOffset]
        else:
            # Random crop: one offset per spatial axis.
            cropOrigin = tf.random_uniform([2],
                                           0,
                                           self.inputImageSize -
                                           self.tileSize,
                                           dtype=tf.int32)
        tileMaterial = blendedMaterial[:, :,
                                       cropOrigin[0]:cropOrigin[0] +
                                       self.tileSize,
                                       cropOrigin[1]:cropOrigin[1] +
                                       self.tileSize, :]
        tileMaterial.set_shape(
            [None, self.nbTargetsToRead, self.tileSize, self.tileSize, 3])
        tileMaterial = helpers.adaptRougness(tileMaterial)

        # Reshape to be compatible with the rendering algorithm
        # [?, size, size, 12].
        renderTargets = helpers.target_reshape(tileMaterial)
        renderingCount = 1
        ggxRenderer = renderer.GGXRenderer(includeDiffuse=True)

        # Put the targets in -1; 1 before rendering them.
        renderTargets = helpers.preprocess(renderTargets)
        pixelGrid = helpers.generateSurfaceArray(self.tileSize)

        renderedInputs = helpers.generateInputRenderings(
            ggxRenderer,
            renderTargets,
            self.batchSize,
            renderingCount,
            pixelGrid,
            renderingScene,
            jitterLightPos,
            jitterViewPos,
            self.useAmbientLight,
            useAugmentationInRenderings=self.useAugmentationInRenderings)

        # Keep the rendering as produced, without its single-rendering axis.
        self.gammaCorrectedInputsBatch = tf.squeeze(renderedInputs, [1])

        # Correct gamma.
        renderedInputs = tf.pow(renderedInputs, 2.2)
        if self.logInput:
            renderedInputs = helpers.logTensor(renderedInputs)

        # Put the inputs in -1; 1.
        renderedInputs = helpers.preprocess(renderedInputs)

        deshapedTargets = helpers.target_deshape(renderTargets,
                                                 self.nbTargetsToRead)
        return deshapedTargets, renderedInputs
def main():
    """Build the data pipeline and the network graph, then train or test.

    Reads its configuration from the module-level options object ``a``:
    seeds the random generators, writes the options to ``options.json`` in
    the output directory, creates the dataset(s), the model(s) and, for the
    "train" and "finetune" modes, the loss, then runs ``test`` or ``train``
    inside a supervised session.
    """
    if a.seed is None:
        a.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)
    # Loads the options stored with the checkpoint so that they are not mixed
    # up and the generated data corresponds to this training.
    loadCheckpointOption(a.mode, a.checkpoint)

    config = tf.ConfigProto()

    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    optionsPath = os.path.join(a.output_dir, "options.json")
    with open(optionsPath, "w") as f:
        f.write(json.dumps(vars(a), sort_keys=True, indent=4))

    # "train" and "finetune" share most of the set-up below.
    isTraining = a.mode in ("train", "finetune")

    data = dataReader.dataset(
        a.input_dir,
        imageFormat=a.imageFormat,
        trainFolder=a.trainFolder,
        testFolder=a.testFolder,
        nbTargetsToRead=a.nbTargets,
        tileSize=TILE_SIZE,
        inputImageSize=a.input_size,
        batchSize=a.batch_size,
        fixCrop=(a.mode == "test"),
        mixMaterials=isTraining,
        logInput=a.useLog,
        useAmbientLight=a.useAmbientLight,
        useAugmentationInRenderings=not a.NoAugmentationInRenderings)
    # Populate data
    data.loadPathList(a.inputMode, a.mode, isTraining, inputpythonList)

    if a.feedMethod == "render":
        if a.mode == "train":
            data.populateInNetworkFeedGraph(a.renderingScene,
                                            a.jitterLightPos,
                                            a.jitterViewPos,
                                            shuffle=isTraining)
        elif a.mode == "finetune":
            data.populateInNetworkFeedGraphSpatialMix(a.renderingScene,
                                                      shuffle=False,
                                                      imageSize=a.input_size)
    elif a.feedMethod == "files":
        data.populateFeedGraph(shuffle=isTraining)

    if isTraining:
        # Second dataset, used for the tests run regularly during training.
        with tf.name_scope("recurrentTest"):
            dataTest = dataReader.dataset(
                a.input_dir,
                imageFormat=a.imageFormat,
                testFolder=a.testFolder,
                nbTargetsToRead=a.nbTargets,
                tileSize=TILE_SIZE,
                inputImageSize=a.test_input_size,
                batchSize=a.batch_size,
                fixCrop=True,
                mixMaterials=False,
                logInput=a.useLog,
                useAmbientLight=a.useAmbientLight,
                useAugmentationInRenderings=not a.NoAugmentationInRenderings)
            dataTest.loadPathList(a.inputMode, "test", False, inputpythonList)
            if a.testApproach == "render":
                dataTest.populateInNetworkFeedGraph(a.renderingScene,
                                                    a.jitterLightPos,
                                                    a.jitterViewPos,
                                                    shuffle=False)
            elif a.testApproach == "files":
                dataTest.populateFeedGraph(False)

    targetsReshaped = helpers.target_reshape(data.targetBatch)

    # Create the model
    model = mod.Model(data.inputBatch, generatorOutputChannels=9)
    model.create_model()
    if isTraining:
        testTargetsReshaped = helpers.target_reshape(dataTest.targetBatch)

        # Same network applied to the test data, sharing the variables.
        testmodel = mod.Model(dataTest.inputBatch,
                              generatorOutputChannels=9,
                              reuse_bool=True)
        testmodel.create_model()
        display_fetches_test, _ = helpers.display_images_fetches(
            dataTest.pathBatch, dataTest.inputBatch, dataTest.targetBatch,
            dataTest.gammaCorrectedInputsBatch, testmodel.output, a.nbTargets,
            a.logOutputAlbedos)

        learningRatePlaceholder = tf.placeholder(tf.float64,
                                                 shape=(),
                                                 name="lr")
        loss = losses.Loss(a.loss, model.output, targetsReshaped, TILE_SIZE,
                           a.batch_size, learningRatePlaceholder,
                           a.includeDiffuse, a.nbSpecularRendering,
                           a.nbDiffuseRendering)

        loss.createLossGraph()
        loss.createTrainVariablesGraph()

    # Register renderings and loss in Tensorflow
    display_fetches, converted_images = helpers.display_images_fetches(
        data.pathBatch, data.inputBatch, data.targetBatch,
        data.gammaCorrectedInputsBatch, model.output, a.nbTargets,
        a.logOutputAlbedos)
    if a.mode == "train":
        helpers.registerTensorboard(data.pathBatch, converted_images,
                                    a.nbTargets, loss.lossValue, a.batch_size,
                                    loss.targetsRenderings,
                                    loss.outputsRenderings)

    # Run either training or test
    with tf.name_scope("parameter_count"):
        variableSizes = [
            tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()
        ]
        parameter_count = tf.reduce_sum(variableSizes)
    saver = tf.train.Saver(max_to_keep=1)

    if a.checkpoint is not None:
        print("reading model from checkpoint : " + a.checkpoint)
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        # Be careful: this will silently not load variables if they are
        # missing from the graph or from the checkpoint.
        partialSaver = helpers.optimistic_saver(checkpoint)

    if a.summary_freq > 0:
        logdir = a.output_dir
    else:
        logdir = None
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)

    with sv.managed_session("", config=config) as sess:
        sess.run(data.iterator.initializer)
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("restoring model from checkpoint : " + a.checkpoint)
            partialSaver.restore(sess, checkpoint)

        # An explicit number of steps takes precedence over a number of
        # epochs; without either, run (almost) forever.
        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = data.stepsPerEpoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        # The iterator is initialised a second time right before running.
        sess.run(data.iterator.initializer)
        if a.mode == "test":
            test(sess, data, max_steps, display_fetches,
                 output_dir=a.output_dir)

        if isTraining:
            train(sv, sess, data, max_steps, display_fetches,
                  display_fetches_test, dataTest, saver, loss, a.output_dir)
Exemple #4
0
def main():
    """Build the data pipeline and the network graph, then train or test.

    Reads its configuration from the module-level options object ``a``.
    In "test" mode the network is run ``a.maxImages`` times, with 1, 2, ...
    input images, each run writing to its own numbered sub-directory of the
    output directory and using a fresh graph. In the other modes it is run
    once.
    """
    if a.seed is None:
        a.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)
    # Load some options from the checkpoint if we provided one.
    loadCheckpointOption()
    # If we feed the network with renderings done in the network for a test
    # run, we save the images beforehand, to be able to compare later with
    # other networks on the same test set. The run then reads those files.
    if a.mode == "test" and a.feedMethod == "render":
        testHelpers.renderTests(a.input_dir, a.testFolder, a.maxImages,
                                tmpFolder, a.imageFormat, CROP_SIZE,
                                a.nbTargets, a.input_size, a.batch_size,
                                a.renderingScene, a.jitterLightPos,
                                a.jitterViewPos, a.inputMode, a.mode,
                                a.output_dir)
        # NOTE(review): this only binds a local name that is never read in
        # this function - confirm whether a module-level flag was intended.
        generateTmpData = True
        a.nbInputs = a.maxImages
        a.feedMethod = "files"
        a.testFolder = tmpFolder
        a.input_size = CROP_SIZE

    backupOutputDir = a.output_dir
    # We run the network once if we are training...
    nbRun = 1
    # ...and as many times as the maximum number of images we want to treat
    # if testing (to have results with one image, two images, three images
    # etc. and see the improvement).
    if a.mode == "test":
        nbRun = a.maxImages
        a.fixImageNb = True

    # Now run the network nbRun times.
    for runID in range(nbRun):
        maxInputNb = a.maxImages
        if a.mode == "test":
            maxInputNb = runID + 1
            a.output_dir = os.path.join(backupOutputDir, str(runID))
            tf.reset_default_graph()
            # The graph-level seed belongs to the graph that was just
            # discarded; set it again so the new graph is seeded too.
            tf.set_random_seed(a.seed)

        # Create the output dir if it doesn't exist.
        if not os.path.exists(a.output_dir):
            os.makedirs(a.output_dir)

        # Write the parameters of this run to the "options" file.
        with open(os.path.join(a.output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(a), sort_keys=True, indent=4))

        # Create a dataset object.
        data = dataReader.dataset(
            a.input_dir,
            imageType=a.imageFormat,
            trainFolder=a.trainFolder,
            testFolder=a.testFolder,
            inputNumbers=a.nbInputs,
            maxInputToRead=maxInputNb,
            nbTargetsToRead=a.nbTargets,
            cropSize=CROP_SIZE,
            inputImageSize=a.input_size,
            batchSize=a.batch_size,
            fixCrop=(a.mode == "test"),
            mixMaterials=(a.mode == "train"),
            fixImageNb=a.fixImageNb,
            logInput=a.useLog,
            useAmbientLight=a.useAmbientLight,
            jitterRenderings=a.jitterRenderings,
            firstAsGuide=False,
            useAugmentationInRenderings=not a.NoAugmentationInRenderings,
            mode=a.mode)

        # Populate the list of files the dataset will contain.
        data.loadPathList(a.inputMode, a.mode, a.mode == "train")

        # Depending on whether we want to render our input data or directly
        # use files, create the tensorflow data loading system.
        if a.feedMethod == "render":
            data.populateInNetworkFeedGraph(a.renderingScene,
                                            a.jitterLightPos,
                                            a.jitterViewPos,
                                            a.mode == "test",
                                            shuffle=a.mode == "train")
        elif a.feedMethod == "files":
            data.populateFeedGraph(shuffle=a.mode == "train")

        # Reshape the input to have all the images in the first dimension
        # (to treat them in parallel).
        inputReshaped, dyn_batch_size = helpers.input_reshape(
            data.inputBatch, a.NoMaxPooling, a.maxImages)

        if a.mode == "train":
            with tf.name_scope("recurrentTest"):
                # Initialize different data for the tests.
                dataTest = dataReader.dataset(
                    a.input_dir,
                    imageType=a.imageFormat,
                    testFolder=a.testFolder,
                    inputNumbers=a.nbInputs,
                    maxInputToRead=a.maxImages,
                    nbTargetsToRead=a.nbTargets,
                    cropSize=CROP_SIZE,
                    inputImageSize=a.input_size,
                    batchSize=a.batch_size,
                    fixCrop=True,
                    mixMaterials=False,
                    fixImageNb=a.fixImageNb,
                    logInput=a.useLog,
                    useAmbientLight=a.useAmbientLight,
                    jitterRenderings=a.jitterRenderings,
                    firstAsGuide=a.firstAsGuide,
                    useAugmentationInRenderings=not a.NoAugmentationInRenderings,
                    mode=a.mode)
                dataTest.loadPathList(a.inputMode, "test", False)
                if a.feedMethod == "render":
                    dataTest.populateInNetworkFeedGraph(a.renderingScene,
                                                        a.jitterLightPos,
                                                        a.jitterViewPos,
                                                        True,
                                                        shuffle=False)
                elif a.feedMethod == "files":
                    dataTest.populateFeedGraph(False)
                TestinputReshaped, test_dyn_batch_size = helpers.input_reshape(
                    dataTest.inputBatch, a.NoMaxPooling, a.maxImages)

        # Reshape the targets to [?(Batchsize), 256,256,12]
        targetsReshaped = helpers.target_reshape(data.targetBatch)

        # Create the object to contain the network model.
        model = mod.Model(inputReshaped,
                          dyn_batch_size,
                          last_convolutions_channels=last_convs_chans,
                          generatorOutputChannels=64,
                          useCoordConv=a.useCoordConv,
                          firstAsGuide=a.firstAsGuide,
                          NoMaxPooling=a.NoMaxPooling,
                          pooling_type=a.poolingtype)

        # Initialize the model.
        model.create_model()

        if a.mode == "train":
            # Initialize the regular test network with different data so that
            # it can run regular test sets.
            testTargetsReshaped = helpers.target_reshape(dataTest.targetBatch)
            testmodel = mod.Model(TestinputReshaped,
                                  test_dyn_batch_size,
                                  last_convolutions_channels=last_convs_chans,
                                  generatorOutputChannels=64,
                                  reuse_bool=True,
                                  useCoordConv=a.useCoordConv,
                                  firstAsGuide=a.firstAsGuide,
                                  NoMaxPooling=a.NoMaxPooling,
                                  pooling_type=a.poolingtype)
            testmodel.create_model()

            # Organize the images we want to retrieve from the test network
            # run.
            display_fetches_test, _ = helpers.display_images_fetches(
                dataTest.pathBatch, dataTest.inputBatch, dataTest.targetBatch,
                dataTest.gammaCorrectedInputsBatch, testmodel.output,
                a.nbTargets, a.logOutputAlbedos)

            # Compute the training network loss.
            loss = losses.Loss(a.loss, model.output, targetsReshaped,
                               CROP_SIZE, a.batch_size,
                               tf.placeholder(tf.float64, shape=(), name="lr"),
                               a.includeDiffuse)
            loss.createLossGraph()

            # Create the training graph part.
            loss.createTrainVariablesGraph()

        # Organize the images we want to retrieve from the train network run.
        display_fetches, converted_images = helpers.display_images_fetches(
            data.pathBatch, data.inputBatch, data.targetBatch,
            data.gammaCorrectedInputsBatch, model.output, a.nbTargets,
            a.logOutputAlbedos)
        if a.mode == "train":
            # Register inputs, targets, renderings and loss in Tensorboard.
            helpers.registerTensorboard(data.pathBatch, converted_images,
                                        a.maxImages, a.nbTargets,
                                        loss.lossValue, a.batch_size,
                                        loss.targetsRenderings,
                                        loss.outputsRenderings)

        # Compute how many parameters the network has.
        with tf.name_scope("parameter_count"):
            parameter_count = tf.reduce_sum([
                tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()
            ])

        # Initialize a saver.
        saver = tf.train.Saver(max_to_keep=1)
        if a.checkpoint is not None:
            print("reading model from checkpoint : " + a.checkpoint)
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            partialSaver = helpers.optimistic_saver(checkpoint)
        logdir = a.output_dir if a.summary_freq > 0 else None
        sv = tf.train.Supervisor(logdir=logdir,
                                 save_summaries_secs=0,
                                 saver=None)
        with sv.managed_session() as sess:
            print("parameter_count =", sess.run(parameter_count))

            # Load the checkpoint.
            if a.checkpoint is not None:
                print("restoring model from checkpoint : " + a.checkpoint)
                partialSaver.restore(sess, checkpoint)

            # Evaluate how many steps to run: an explicit number of steps
            # takes precedence over a number of epochs.
            max_steps = 2**32
            if a.max_epochs is not None:
                max_steps = data.stepsPerEpoch * a.max_epochs
            if a.max_steps is not None:
                max_steps = a.max_steps

            # If we want to run a test
            if a.mode == "test" or a.mode == "eval":
                filesets = test(sess,
                                data,
                                max_steps,
                                display_fetches,
                                output_dir=a.output_dir)
                # At the last iteration of a multi-run test, generate the
                # full html.
                if runID == nbRun - 1 and runID >= 1:
                    helpers.writeGlobalHTML(backupOutputDir, filesets,
                                            a.nbTargets, a.mode, a.maxImages)
            # If we want to train
            if a.mode == "train":
                train(sv, sess, data, max_steps, display_fetches,
                      display_fetches_test, dataTest, saver, loss)
    def __renderInputs(self, materials, renderingScene, jitterLightPos,
                       jitterViewPos, mixMaterials, isTest, renderSize):
        """Render the network inputs from the materials and crop them.

        Args:
            materials: Batch of material maps, declared as
                ``[None, nbTargetsToRead, renderSize, renderSize, 3]`` after
                the optional mixing.
            renderingScene: Forwarded to ``helpers.generateInputRenderings``.
            jitterLightPos: Forwarded to ``helpers.generateInputRenderings``.
            jitterViewPos: Forwarded to ``helpers.generateInputRenderings``.
            mixMaterials: When True, even- and odd-indexed materials of the
                batch are blended pairwise with a random weight in
                [0.1, 0.9).
            isTest: Not used by the active code of this method.
            renderSize: Side, in pixels, of the materials and renderings
                before they are cropped.

        Returns:
            A ``(targets, inputs)`` tuple: the cropped targets brought back
            to one entry per target map, and the preprocessed renderings.

        Also stores the cropped renderings, before gamma correction, in
        ``self.gammaCorrectedInputsBatch`` and registers them as the
        "GammadInputs" image summary.
        """
        mixedMaterial = materials
        if mixMaterials:
            alpha = tf.random_uniform([1],
                                      minval=0.1,
                                      maxval=0.9,
                                      dtype=tf.float32,
                                      name="mixAlpha")

            materials1 = materials[::2]
            materials2 = materials[1::2]

            mixedMaterial = helpers.mixMaterials(materials1, materials2, alpha)
        mixedMaterial.set_shape(
            [None, self.nbTargetsToRead, renderSize, renderSize, 3])
        mixedMaterial = helpers.adaptRougness(mixedMaterial)
        # The lines below try to scale the albedos to get more variety and to
        # randomly flatten the normals to disambiguate normals and albedos.
        # No strong effect was observed, so they are disabled.
        #if not isTest and self.useAugmentationInRenderings:
        #    mixedMaterial = helpers.adaptAlbedos(mixedMaterial, self.batchSize)
        #    mixedMaterial = helpers.adaptNormals(mixedMaterial, self.batchSize)

        # Reshape to be compatible with the rendering algorithm
        # [?, size, size, 12].
        reshaped_targets_batch = helpers.target_reshape(mixedMaterial)
        nbRenderings = self.maxInputToRead
        if not self.fixImageNb:
            # If we don't want a constant number of input images, randomly
            # select a number of input images between 1 and the maximum
            # number of images defined by the user.
            nbRenderings = tf.random_uniform([1],
                                             1,
                                             self.maxInputToRead + 1,
                                             dtype=tf.int32)[0]
        rendererInstance = renderer.GGXRenderer(includeDiffuse=True)

        targetstoRender = reshaped_targets_batch
        pixelsToAdd = 0

        # Put the targets in -1; 1.
        targetstoRender = helpers.preprocess(targetstoRender)
        # Generate a grid Y,X between -1;1 to act as the pixel support of the
        # rendering (to compute the direction vector between each pixel and
        # the light/view).
        surfaceArray = helpers.generateSurfaceArray(renderSize, pixelsToAdd)

        # Do the renderings of the mixed material.
        inputs = helpers.generateInputRenderings(
            rendererInstance,
            targetstoRender,
            self.batchSize,
            nbRenderings,
            surfaceArray,
            renderingScene,
            jitterLightPos,
            jitterViewPos,
            self.useAmbientLight,
            useAugmentationInRenderings=self.useAugmentationInRenderings)

        # Without jittering every rendering is cropped from the same corner.
        randomTopLeftCrop = tf.zeros([self.batchSize, nbRenderings, 2],
                                     dtype=tf.int32)
        averageCrop = 0.0

        # If we want to jitter the renderings around (to try to take small
        # misalignments into account), the material crop is handled a bit
        # differently. This did not give satisfying results so far.
        if self.jitterRenderings:
            randomTopLeftCrop = tf.random_normal(
                [self.batchSize, nbRenderings, 2], 0.0, 1.0)
            # One log-normal scale per batch element. It needs the shape
            # [batchSize, 1, 1] to broadcast over the renderings and the two
            # coordinates: a [batchSize] vector would be matched against the
            # last (coordinate) axis instead.
            jitterScale = tf.exp(
                tf.random_normal([self.batchSize, 1, 1], 0.0, 1.0))
            randomTopLeftCrop = randomTopLeftCrop * jitterScale
            # Centre the offsets of each batch element around zero.
            randomTopLeftCrop = randomTopLeftCrop - tf.reduce_mean(
                randomTopLeftCrop, axis=1, keep_dims=True)
            randomTopLeftCrop = tf.round(randomTopLeftCrop)
            randomTopLeftCrop = tf.cast(randomTopLeftCrop, dtype=tf.int32)
            # Shift the offsets to the middle of the allowed jitter range and
            # keep them inside it.
            averageCrop = tf.cast(self.maxJitteringPixels * 0.5,
                                  dtype=tf.int32)
            randomTopLeftCrop = randomTopLeftCrop + averageCrop
            randomTopLeftCrop = tf.clip_by_value(randomTopLeftCrop, 0,
                                                 self.maxJitteringPixels)

        totalCropSize = self.cropSize

        inputs, targets = helpers.cutSidesOut(inputs, targetstoRender,
                                              randomTopLeftCrop, totalCropSize,
                                              self.firstAsGuide, averageCrop)
        print("inputs shape after" + str(inputs.get_shape()))

        self.gammaCorrectedInputsBatch = inputs
        tf.summary.image("GammadInputs",
                         helpers.convert(inputs[0, :]),
                         max_outputs=5)
        # Correct gamma.
        inputs = tf.pow(inputs, 2.2)
        if self.logInput:
            inputs = helpers.logTensor(inputs)

        inputs = helpers.preprocess(inputs)
        targets = helpers.target_deshape(targets, self.nbTargetsToRead)
        return targets, inputs
Exemple #6
0
    def populateInNetworkFeedGraphSpatialMix(self, renderingScene, shuffle=True, imageSize=512, useSpatialMix=True):
        """Build the pipeline that renders the inputs from (mixed) materials.

        Every entry of ``self.pathList`` is mapped through
        ``self.__readImagesGT``, optionally shuffled, repeated indefinitely
        and batched. With ``useSpatialMix`` twice as many materials are
        pulled per batch and combined pairwise through a thresholded Perlin
        noise mask. The materials are then rendered once each and, when
        ``imageSize`` is larger than ``self.tileSize``, inputs and targets
        are cropped to ``tileSize`` x ``tileSize`` (centred when
        ``self.fixCrop`` is set, at a random position otherwise).

        Args:
            renderingScene: Forwarded to ``helpers.generateInputRenderings``.
            shuffle: When True, shuffle the examples with a 16-element buffer.
            imageSize: Side, in pixels, of the materials that are read.
            useSpatialMix: When True, mix the materials two by two.

        Sets ``self.gammaCorrectedInputsBatch``, ``self.stepsPerEpoch``,
        ``self.inputBatch``, ``self.targetBatch``, ``self.iterator`` and
        ``self.pathBatch``. The iterator is an initializable one: its
        ``initializer`` has to be run before batches are fetched.
        """
        with tf.name_scope("load_images"):
            # Create a tensor out of the list of paths.
            filenamesTensor = tf.constant(self.pathList)
            # One dataset element per entry along the first axis of the tensor.
            dataset = tf.data.Dataset.from_tensor_slices(filenamesTensor)

            # Use a quarter of the cores to read images, but never fewer than
            # one: int(cpu_count() / 4) is 0 on machines with fewer than four
            # cores, which is not a valid num_parallel_calls value.
            nbParallelCalls = max(1, multiprocessing.cpu_count() // 4)
            dataset = dataset.map(self.__readImagesGT, num_parallel_calls=nbParallelCalls)
            if shuffle:
                dataset = dataset.shuffle(buffer_size=16, reshuffle_each_iteration=True)
            # Authorize repetition of the dataset when one epoch is over.
            dataset = dataset.repeat()
            # Set the batch size: two materials are needed per mixed example.
            toPull = self.batchSize
            if useSpatialMix:
                toPull = self.batchSize * 2
            batched_dataset = dataset.batch(toPull)
            batched_dataset = batched_dataset.prefetch(buffer_size=4)
            # Create an iterator to be initialized by the caller.
            iterator = batched_dataset.make_initializable_iterator()

            # Create the node to retrieve the next batch.
            paths_batch, targets_batch = iterator.get_next()
            # Should be the input image size, but temporarily taken from the
            # imageSize argument.
            inputRealSize = imageSize

            if useSpatialMix:
                # Binary mask: 1 where the noise, rescaled from [-1, 1] to
                # [0, 1], reaches the threshold.
                threshold = 0.5
                perlinNoise = tf.expand_dims(tf.expand_dims(helpers.generate_perlin_noise_2d((inputRealSize, inputRealSize), (1, 1)), axis=-1), axis=0)
                perlinNoise = (perlinNoise + 1.0) * 0.5
                perlinNoise = perlinNoise >= threshold
                perlinNoise = tf.cast(perlinNoise, tf.float32)
                inverted = 1.0 - perlinNoise

                # Even-indexed materials fill the mask, odd-indexed ones fill
                # its complement.
                materialsMixed1 = targets_batch[::2] * perlinNoise
                materialsMixed2 = targets_batch[1::2] * inverted

                fullSizeMixedMaterial = materialsMixed1 + materialsMixed2
                targets_batch = fullSizeMixedMaterial
                # Keep the path of the first material of each pair.
                paths_batch = paths_batch[::2]

            # Reshape to be compatible with the rendering algorithm
            # [?, size, size, 12].
            targetstoRender = helpers.target_reshape(targets_batch)
            nbRenderings = 1
            rendererInstance = renderer.GGXRenderer(includeDiffuse=True)
            # NOTE(review): this result is never used, so the roughness
            # adaptation is not applied to what is rendered - confirm whether
            # that is intended.
            mixedMaterial = helpers.adaptRougness(targetstoRender)

            # Put the targets in -1; 1.
            targetstoRender = helpers.preprocess(targetstoRender)
            surfaceArray = helpers.generateSurfaceArray(inputRealSize)

            # Do the renderings of the mixed material, without jittering the
            # light or view position.
            inputs_batch = helpers.generateInputRenderings(rendererInstance, targetstoRender, self.batchSize, nbRenderings, surfaceArray, renderingScene, False, False, self.useAmbientLight, useAugmentationInRenderings=self.useAugmentationInRenderings)

            targets_batch = helpers.target_deshape(targetstoRender, self.nbTargetsToRead)
            # Stored before the crop below, so at the full rendering size.
            self.gammaCorrectedInputsBatch = tf.squeeze(inputs_batch, [1])
            # Correct gamma.
            inputs_batch = tf.pow(inputs_batch, 2.2)
            if self.logInput:
                inputs_batch = helpers.logTensor(inputs_batch)

            # Crop to tileSize x tileSize: centred if the crop is fixed,
            # random otherwise.
            if inputRealSize > self.tileSize:
                if self.fixCrop:
                    centerOffset = (inputRealSize - self.tileSize) // 2
                    xyCropping = [centerOffset, centerOffset]
                else:
                    # Draw one offset per spatial axis. Drawing a single
                    # value and reusing it for both axes restricts the random
                    # crops to the diagonal of the image.
                    xyCropping = tf.random_uniform([2], 0, inputRealSize - self.tileSize, dtype=tf.int32)

                inputs_batch = inputs_batch[:, :, xyCropping[0]:xyCropping[0] + self.tileSize, xyCropping[1]:xyCropping[1] + self.tileSize, :]
                targets_batch = targets_batch[:, :, xyCropping[0]:xyCropping[0] + self.tileSize, xyCropping[1]:xyCropping[1] + self.tileSize, :]

            # Set the static shapes. Before the squeeze the input has a
            # useless dimension in 1, as there is only one rendering.
            inputs_batch = tf.squeeze(inputs_batch, [1])
            inputs_batch.set_shape([None, self.tileSize, self.tileSize, 3])
            targets_batch.set_shape([None, self.nbTargetsToRead, self.tileSize, self.tileSize, 3])

            # Populate the object.
            # NOTE(review): with useSpatialMix each step reads 2 * batchSize
            # paths while this counts batchSize paths per step - confirm the
            # intended epoch length.
            self.stepsPerEpoch = int(math.floor(len(self.pathList) / self.batchSize))
            self.inputBatch = inputs_batch
            self.targetBatch = targets_batch
            self.iterator = iterator
            self.pathBatch = paths_batch