Example #1
# Shared imports for the examples in this listing
import os
import sys

import nibabel as nib
import numpy as np
import tensorflow as tf


def SampleTrain():
    # Compare every binarised test volume against a reference volume and
    # report the Dice coefficient for each.
    # load_3d, binarisation and DICE are repository-local helpers; sketches
    # of what they presumably do follow the examples below.
    path = os.path.join(".", "data_newOK")
    path2 = os.path.join(".", "niftiImage/split/Test")

    filenames = sorted(os.listdir(path))
    filenames2 = sorted(os.listdir(path2))

    # The reference volume only needs to be loaded once.
    imagefile = os.path.join(path, filenames[0])
    data = nib.load(imagefile).get_fdata()
    affine = load_3d.get_affine_3d()

    b = []
    for name in filenames2:
        imagefile2 = os.path.join(path2, name)
        data2 = nib.load(imagefile2).get_fdata()

        a = binarisation(data2, 0.5)
        save_img = nib.Nifti1Image(a, affine)
        #nib.save(save_img, 'check/1000.nii.gz')

        diff = DICE(a, data)
        b.append(diff)
        print(diff)

    c = np.asarray(b)  # Dice scores for all test volumes
    return c
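The DICE helper called above is not part of this listing. A minimal sketch of what it presumably computes, the standard Dice overlap 2|A∩B| / (|A| + |B|) between two binary volumes; the name and signature follow the call above, but the body is an assumption:

def DICE(a, b):
    # Assumed implementation: Dice coefficient of two binary volumes.
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0  # both volumes empty: treat as perfect overlap
    return 2.0 * np.logical_and(a, b).sum() / denom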
Example #2
def reset_Data():
    # Re-save every generated sample with the reference affine, under a
    # sequential integer filename.
    # (An earlier one-off pass zero-padded the numeric part of the filenames
    # in "data_newOK", e.g. "img_7_x" -> "img_007_x".)
    path2 = os.path.join(".", "niftiImage/genIter/Sample")

    affine = load_3d.get_affine_3d()
    for i, file in enumerate(sorted(os.listdir(path2))):
        imagefile2 = os.path.join(path2, file)
        img2 = nib.load(imagefile2)
        data2 = img2.get_fdata()  # get_data() is deprecated in nibabel
        save_img = nib.Nifti1Image(data2, affine)

        nib.save(save_img, 'niftiImage/genIter/SampleReset/%d.nii.gz' % i)
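load_3d.get_affine_3d() appears in every example but is not shown. It presumably returns the voxel-to-world affine of a reference NIfTI volume; a sketch under that assumption (the reference path is hypothetical):

def get_affine_3d(reference='data_newOK/reference.nii.gz'):
    # nibabel exposes the voxel-to-world matrix as the .affine attribute.
    return nib.load(reference).affine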
Example #3
def binarise_data(path2, fold_no, threshold):
    # Threshold every generated sample in path2 and save the resulting
    # binary mask under the matching fold directory.
    # Typical argument: path2 = os.path.join(".", "niftiImage/genIterReal/Sample")
    affine = load_3d.get_affine_3d()
    for i, file in enumerate(sorted(os.listdir(path2))):
        imagefile2 = os.path.join(path2, file)
        img2 = nib.load(imagefile2)
        data2 = img2.get_fdata()  # get_data() is deprecated in nibabel

        a = binarisation(data2, threshold)
        save_img = nib.Nifti1Image(a, affine)
        nib.save(
            save_img,
            'niftiImage/splitGenIter3/SampleBinaryMask/%d/%d.nii.gz' %
            (fold_no + 1, i))
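The binarisation helper is likewise external. Given how it is called, a volume plus a threshold with a binary mask as output, a plausible sketch is a plain threshold; this is an assumption, not the repository's code:

def binarisation(data, threshold):
    # Assumed behaviour: voxels above the threshold become 1.0, the rest 0.0.
    return (data > threshold).astype(np.float32)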
Example #4
def clean_binarise(binary_path, fold_no):
    # Remove outlier voxels from already-binarised masks and re-save them.
    affine = load_3d.get_affine_3d()

    for i, file in enumerate(sorted(os.listdir(binary_path))):
        imagefile2 = os.path.join(binary_path, file)
        img2 = nib.load(imagefile2)
        data2 = img2.get_fdata()  # get_data() is deprecated in nibabel

        # Foreground voxel coordinates, with outliers removed
        points = np.transpose(np.where(data2))
        points = remove_outliers(points)

        # Rebuild a 96x96x96 binary volume from the cleaned point cloud
        data = np.zeros((96, 96, 96))
        data[points[:, 0], points[:, 1], points[:, 2]] = 1.
        save_img = nib.Nifti1Image(data, affine)
        nib.save(
            save_img, 'niftiImage/splitGenIter3/CleanBinaryMask/%d/%d.nii.gz' %
            (fold_no + 1, i))
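remove_outliers takes an (N, 3) array of foreground voxel coordinates and returns a filtered copy. One hedged possibility is a distance-from-centroid filter; both the heuristic and the three-sigma cut-off are illustrative assumptions:

def remove_outliers(points):
    # Keep points whose distance to the centroid is within three standard
    # deviations of the mean distance (assumed heuristic).
    centroid = points.mean(axis=0)
    dist = np.linalg.norm(points - centroid, axis=1)
    return points[dist < dist.mean() + 3 * dist.std()]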
Example #5
def full_binarise(path2, fold_no, fold=False):
    # Threshold, remove outlier voxels, then save the cleaned binary mask.
    # (fold is accepted for interface compatibility but is not used here.)
    affine = load_3d.get_affine_3d()

    for i, file in enumerate(sorted(os.listdir(path2))):
        imagefile2 = os.path.join(path2, file)
        img2 = nib.load(imagefile2)
        data2 = img2.get_fdata()  # get_data() is deprecated in nibabel

        a = binarisation(data2, 0.5)
        points = np.transpose(np.where(a))
        points = remove_outliers(points)

        a = np.zeros((96, 96, 96))
        a[points[:, 0], points[:, 1], points[:, 2]] = 1.

        save_img = nib.Nifti1Image(a, affine)
        nib.save(
            save_img,
            'niftiImage/splitGenIterSave/CleanBinaryMask/%d/%d.nii.gz' %
            (fold_no + 1, i))
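A typical driver for these per-fold helpers might look like the sketch below; the input directory layout and the five-fold split are assumptions inferred from the save paths above:

for fold_no in range(5):  # assumed 5-fold cross-validation
    full_binarise('niftiImage/splitGenIterSave/Sample/%d' % (fold_no + 1), fold_no)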
Example #6
def DBSCAN_binarise(path2, fold_no, fold=False):
    # Like full_binarise, but additionally clusters the foreground points
    # with DBSCAN and zeroes out anything classified as noise.
    affine = load_3d.get_affine_3d()

    for i, file in enumerate(sorted(os.listdir(path2))):
        imagefile2 = os.path.join(path2, file)
        img2 = nib.load(imagefile2)
        data2 = img2.get_fdata()  # get_data() is deprecated in nibabel

        a = binarisation(data2, 0.5)
        points = np.transpose(np.where(a))
        points = remove_outliers(points)

        points, noise = DBSCAN_outliers(points)
        data = np.zeros((96, 96, 96))
        data[points[:, 0], points[:, 1], points[:, 2]] = 1.
        # Zeroing the noise voxels is a no-op on a freshly zeroed volume,
        # but it makes the intent explicit.
        data[noise[:, 0], noise[:, 1], noise[:, 2]] = 0.

        save_img = nib.Nifti1Image(data, affine)
        # Saving is disabled in this variant:
        #nib.save(save_img, 'niftiImage/1axis/DBSCANBinaryMask/%d/%d.nii.gz' %(fold_no+1, i))
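DBSCAN_outliers presumably wraps scikit-learn's DBSCAN and splits the point cloud into clustered points and noise; a sketch under that assumption, with placeholder eps and min_samples values:

from sklearn.cluster import DBSCAN

def DBSCAN_outliers(points, eps=2.0, min_samples=5):
    # scikit-learn marks noise points with the cluster label -1.
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(points)
    return points[labels != -1], points[labels == -1]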
Example #7
    def train_k_folds(self, files, fold_no, epochs, batch_size=1, sample_interval=100, d_iter=0, g_iter=0, save=True):

        print("Fold: %d\n" % (fold_no + 1))

        # Load the train and test sets for this fold
        x_train_real, x_test_real = load_3d.load_US_3d_fold(fold_no)
        x1_train_real, x1_test_real, x2_train_real, x2_test_real, x3_train_real, x3_test_real = load_3d.load_2d_fold(fold_no)
        affine = load_3d.get_affine_3d()


        # TensorBoard log directory for this fold
        log_path = files['logs'] + "%d" % (fold_no + 1)
        writer = tf.summary.FileWriter(log_path)

        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        # Start the batch generator from the per-fold config file
        confTrain = {}
        config_path = files['k_config_path'] + "%d.cfg" % (fold_no + 1)
        if sys.version_info[0] < 3:
            execfile(config_path, confTrain)  # Python 2 only
        else:
            exec(open(config_path).read(), confTrain)


        assert epochs == confTrain['numEpochs'], \
            "Number of epochs must match the config file (%d)" % confTrain['numEpochs']

        assert batch_size == confTrain['batchSizeTraining'], \
            "Batch size must match the config file (%d)" % confTrain['batchSizeTraining']


        # Count training and validation cases from the channel listing files
        with open(confTrain['channelsTraining'][0]) as f:
            train_length = sum(1 for _ in f)

        with open(confTrain['channelsValidation_ID']) as f:
            val_length = sum(1 for _ in f)

        batchGen = BatchGeneratorVolAnd2DplanesMultiThread(confTrain, mode='training', infiniteLoop=False, maxQueueSize = 5)
        # Validation batch generator
        batchGenV = BatchGeneratorVolAnd2DplanesMultiThread(confTrain, mode='validation', infiniteLoop=False, maxQueueSize = 4)
        batchGen.generateBatches()
        batchGenV.generateBatches()

        for epoch in range(epochs):
            avg_d_cost = 0
            avg_g_cost = 0
            avg_acc = 0
            total_batches = int(train_length / batch_size)

            for batch_no in range(total_batches):

                x_train, _, x1_train, x2_train, x3_train = batchGen.getBatchAndScalingFactor()

                # Move channels to the last axis
                x_train = np.rollaxis(x_train, 1, 5)
                x1_train = np.rollaxis(x1_train, 1, 4)
                x2_train = np.rollaxis(x2_train, 1, 4)
                x3_train = np.rollaxis(x3_train, 1, 4)

                imgs = x_train
                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Sample noise as generator input
                noise = np.random.normal(0, 1, (batch_size, 96, 96))
                noise = np.reshape(noise, (batch_size, 96, 96, 1))

                # Generate a batch of new images
                gen_imgs = self.generator.predict([noise, x2_train, x3_train])

                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch([imgs, x2_train, x3_train], valid)
                d_loss_fake = self.discriminator.train_on_batch([gen_imgs, x2_train, x3_train], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                avg_d_cost += d_loss[0]/(total_batches*(d_iter+1))
                avg_acc += d_loss[1]/(total_batches*(d_iter+1))

                # Train the discriminator for d_iter extra iterations
                for iteration in range(d_iter):
                    d_loss_real = self.discriminator.train_on_batch([imgs, x2_train, x3_train], valid)
                    d_loss_fake = self.discriminator.train_on_batch([gen_imgs, x2_train, x3_train], fake)
                    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                    avg_d_cost += d_loss[0]/(total_batches*(d_iter+1))
                    avg_acc += d_loss[1]/(total_batches*(d_iter+1))
                # ---------------------
                #  Train Generator
                # ---------------------

                # Train the generator
                g_loss = self.combined.train_on_batch([noise, x2_train, x3_train], valid)
                avg_g_cost += g_loss/(total_batches*(g_iter + 1))

                # Extra generator iterations
                for i in range(g_iter):
                    g_loss = self.combined.train_on_batch([noise, x2_train, x3_train], valid)
                    avg_g_cost += g_loss/(total_batches*(g_iter + 1))

            # Calculate validation loss after 1 epoch
            val_loss = 0

            for i in range(val_length):
                #get noise
                noise = np.random.normal(0, 1, (1, 96,96))
                noise = np.reshape(noise,(1,96,96,1))

                #Get condition
                _, x_val, x1_val, x2_val, x3_val = batchGenV.getBatchAndScalingFactor()
                x_val = np.rollaxis(x_val, 1, 5)
                x2 = np.rollaxis(x2_val, 1, 4)
                x3 = np.rollaxis(x3_val, 1, 4)

                gen_imgs = self.generator.predict([noise, x2, x3])
                gen_imgs = np.reshape(gen_imgs, (96,96,96))
                val_loss += soft_dice_loss(gen_imgs, x_val)/(val_length)


            print("%d: [D loss %f, acc.: %.2f%%] [G loss: %f]" % (epoch, avg_d_cost, 100*avg_acc, avg_g_cost))
            print("Epoch: %d, val loss: %f" % (epoch, val_loss))

            # Add scalars to the TensorBoard logs
            summary = tf.Summary(value=[tf.Summary.Value(tag="d_loss", simple_value=avg_d_cost)])
            writer.add_summary(summary, global_step=epoch)
            summary2 = tf.Summary(value=[tf.Summary.Value(tag="acc", simple_value=avg_acc)])
            writer.add_summary(summary2, global_step=epoch)
            summary3 = tf.Summary(value=[tf.Summary.Value(tag="g_loss", simple_value=avg_g_cost)])
            writer.add_summary(summary3, global_step=epoch)
            summary4 = tf.Summary(value=[tf.Summary.Value(tag="val_loss", simple_value=val_loss)])
            writer.add_summary(summary4, global_step=epoch)

            # At the sampling interval, save generated image samples
            if sample_interval != 0 and epoch % sample_interval == 0:
                self.sample_images(epoch, x_train_real[0], x1_train_real[0], x2_train_real[0], x3_train_real[0],
                                   affine, fold_no, save_path=files['sample_path'])

            # On the final epoch, save predictions for the whole test set
            if save and epochs - epoch == 1:
                self.sample_all(x_test_real, x1_test_real, x2_test_real, x3_test_real, affine, fold_no,
                                save_path=files['predict_path'], model_path=files['models'])

        batchGen.finish()
        batchGenV.finish()
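The validation loop relies on a soft_dice_loss helper that is not shown. A common formulation is sketched below as an assumption about what the repository implements; the smoothing constant of 1.0 is a typical default:

def soft_dice_loss(pred, target, smooth=1.0):
    # Soft Dice operates on continuous predictions, so no thresholding is needed.
    pred = np.ravel(pred)
    target = np.ravel(target)
    intersection = np.sum(pred * target)
    dice = (2.0 * intersection + smooth) / (np.sum(pred) + np.sum(target) + smooth)
    return 1.0 - dice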
Example #8
    def train(self, files, epochs, batch_size=1, sample_interval=100, d_iter=0, g_iter=0):

        _, x_test_real = load_3d.load_US_3d_fold()
        _, x1_test_real, _, x2_test_real, _, x3_test_real = load_3d.load_2d_fold()
        affine = load_3d.get_affine_3d()

        writer = tf.summary.FileWriter(files['logs'])

        # Start the batch generator from the config file
        confTrain = {}
        if sys.version_info[0] < 3:
            execfile(files['config_path'], confTrain)  # Python 2 only
        else:
            exec(open(files['config_path']).read(), confTrain)

        assert epochs == confTrain['numEpochs'], \
            "Number of epochs must match the config file (%d)" % confTrain['numEpochs']

        assert batch_size == confTrain['batchSizeTraining'], \
            "Batch size must match the config file (%d)" % confTrain['batchSizeTraining']


        # Count training and validation cases from the channel listing files
        with open(confTrain['channelsTraining'][0]) as f:
            train_length = sum(1 for _ in f)

        with open(confTrain['channelsValidation_ID']) as f:
            val_length = sum(1 for _ in f)

        batchGen = BatchGeneratorVolAnd2DplanesMultiThread(confTrain, mode='training', infiniteLoop=False, maxQueueSize = 5)
        batchGen.generateBatches()

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            avg_d_cost = 0
            avg_g_cost = 0
            avg_acc = 0
            total_batches = int(train_length / batch_size)

            for batch_no in range(total_batches):

                x_train, _ , x1_train, x2_train, x3_train = batchGen.getBatchAndScalingFactor()

                # Move channels to last axis
                x_train = np.rollaxis(x_train, 1, 5)
                x1_train = np.rollaxis(x1_train, 1, 4)
                x2_train = np.rollaxis(x2_train, 1, 4)
                x3_train = np.rollaxis(x3_train, 1, 4)

                imgs = x_train
                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Sample noise as generator input
                noise = np.random.normal(0, 1, (batch_size, 96, 96))
                noise = np.reshape(noise, (batch_size, 96, 96, 1))

                # Generate a batch of new images
                gen_imgs = self.generator.predict([noise, x2_train, x3_train])

                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch([imgs, x2_train, x3_train], valid)
                d_loss_fake = self.discriminator.train_on_batch([gen_imgs, x2_train, x3_train], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                avg_d_cost += d_loss[0]/(total_batches*(d_iter+1))
                avg_acc += d_loss[1]/(total_batches*(d_iter+1))

                # Train the discriminator for d_iter extra iterations
                for iterations in range(d_iter):
                    d_loss_real = self.discriminator.train_on_batch([imgs, x2_train, x3_train], valid)
                    d_loss_fake = self.discriminator.train_on_batch([gen_imgs, x2_train, x3_train], fake)
                    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                    avg_d_cost += d_loss[0]/(total_batches*(d_iter+1))
                    avg_acc += d_loss[1]/(total_batches*(d_iter+1))

                # ---------------------
                #  Train Generator
                # ---------------------

                # Train the generator, with g_iter extra iterations
                g_loss = self.combined.train_on_batch([noise, x2_train, x3_train], valid)
                avg_g_cost += g_loss/(total_batches*(g_iter + 1))
                for i in range(g_iter):
                    g_loss = self.combined.train_on_batch([noise, x2_train, x3_train], valid)
                    avg_g_cost += g_loss/(total_batches*(g_iter + 1))


            print("%d: [D loss %f, acc.: %.2f%%] [G loss: %f]" % (epoch, avg_d_cost, 100*avg_acc, avg_g_cost))
            #print("Epoch: %d, val loss: %f" % (epoch, val_loss))


            # Add scalars to the TensorBoard logs
            summary = tf.Summary(value=[tf.Summary.Value(tag="d_loss", simple_value=avg_d_cost)])
            writer.add_summary(summary, global_step=epoch)
            summary2 = tf.Summary(value=[tf.Summary.Value(tag="acc", simple_value=avg_acc)])
            writer.add_summary(summary2, global_step=epoch)
            summary3 = tf.Summary(value=[tf.Summary.Value(tag="g_loss", simple_value=avg_g_cost)])
            writer.add_summary(summary3, global_step=epoch)

            # On the final epoch, save predictions for the whole test set
            if epochs - epoch == 1:
                self.sample_all(x_test_real, x1_test_real, x2_test_real, x3_test_real, affine,
                                save_path=files['predict_path'], model_path=files['models'])

        batchGen.finish()
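The files dictionary threaded through train() and train_k_folds() carries every input and output location. Its shape can be inferred from the keys the two methods access; the concrete paths below are illustrative only:

files = {
    'logs': './logs/run',                 # TensorBoard log dir (train_k_folds appends the fold number)
    'config_path': './config/train.cfg',  # batch-generator config used by train()
    'k_config_path': './config/fold',     # per-fold prefix; "<prefix><fold>.cfg" is loaded
    'sample_path': './samples/',
    'predict_path': './predictions/',
    'models': './models/',
}
# Hypothetical invocation: gan.train(files, epochs=200, batch_size=1)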
Example #9
    def train(self,
              fold_no,
              epochs,
              batch_size=1,
              sample_interval=50,
              iter=False,  # NOTE: shadows the built-in iter()
              folds=1):

        # Load the dataset for this fold
        x_train_real, x_test_real = load_3d.load_US_3d_fold(fold_no)
        x1_train_real, x1_test_real, x2_train_real, x2_test_real, x3_train_real, x3_test_real = load_3d.load_2d_fold(
            fold_no)
        affine = load_3d.get_affine_3d()

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        writer = tf.summary.FileWriter('./logs/splitGenIterSave/')

        # Start the batch generator from a hard-coded config file
        confTrain = {}
        if sys.version_info[0] < 3:
            execfile(
                "/homes/wt814/IndividualProject/code/BatchGeneratorClass/BatchGenerator_v2_config_Full.cfg",
                confTrain)  # Python 2 only
        else:
            exec(
                open(
                    "/homes/wt814/IndividualProject/code/BatchGeneratorClass/BatchGenerator_v2_config_Full.cfg"
                ).read(), confTrain)

        assert epochs <= confTrain['numEpochs'], \
            "Number of epochs must not exceed the config file (%d)" % confTrain['numEpochs']

        assert batch_size == confTrain['batchSizeTraining'], \
            "Batch size must match the config file (%d)" % confTrain['batchSizeTraining']

        batchGen = BatchGeneratorVolAnd2DplanesMultiThread(confTrain,
                                                           mode='training',
                                                           infiniteLoop=False,
                                                           maxQueueSize=5)
        batchGen.generateBatches()
        for epoch in range(epochs):
            avg_d_cost = 0
            avg_g_cost = 0
            avg_acc = 0
            total_batches = int(120 / batch_size)  # 120 = hard-coded training-set size

            for batch_no in range(total_batches):

                x_train, _, x1_train, x2_train, x3_train = batchGen.getBatchAndScalingFactor()

                # Move channels to the last axis
                x_train = np.rollaxis(x_train, 1, 5)
                x1_train = np.rollaxis(x1_train, 1, 4)
                x2_train = np.rollaxis(x2_train, 1, 4)
                x3_train = np.rollaxis(x3_train, 1, 4)

                imgs = x_train
                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Sample noise as generator input
                noise = np.random.normal(0, 1, (batch_size, 96, 96))
                noise = np.reshape(noise, (batch_size, 96, 96, 1))

                # Generate a batch of new images
                gen_imgs = self.generator.predict(
                    [noise, x1_train, x2_train, x3_train])

                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(
                    [imgs, x1_train, x2_train, x3_train], valid)
                d_loss_fake = self.discriminator.train_on_batch(
                    [gen_imgs, x1_train, x2_train, x3_train], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                avg_d_cost += d_loss[0] / total_batches
                avg_acc += d_loss[1] / total_batches

                # Optionally train the discriminator a second time
                if iter:
                    d_loss_real = self.discriminator.train_on_batch(
                        [imgs, x1_train, x2_train, x3_train], valid)
                    d_loss_fake = self.discriminator.train_on_batch(
                        [gen_imgs, x1_train, x2_train, x3_train], fake)
                    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag="d_loss_real",
                                         simple_value=d_loss_real[0]),
                    ])
                    summary1 = tf.Summary(value=[
                        tf.Summary.Value(tag="d_loss_fake",
                                         simple_value=d_loss_fake[0]),
                    ])
                    writer.add_summary(summary, global_step=epoch)
                    writer.add_summary(summary1, global_step=epoch)
                    summary2 = tf.Summary(value=[
                        tf.Summary.Value(tag="acc", simple_value=d_loss[1]),
                    ])
                    writer.add_summary(summary2, global_step=epoch)
                # ---------------------
                #  Train Generator
                # ---------------------

                # Train the generator, with one extra iteration; the original
                # for/else here always ran the else branch and double-counted
                # the last g_loss, so the first pass is accumulated directly.
                g_loss = self.combined.train_on_batch(
                    [noise, x1_train, x2_train, x3_train], valid)
                iter_gen = 1
                avg_g_cost += g_loss / (total_batches * (iter_gen + 1))
                for i in range(iter_gen):
                    g_loss = self.combined.train_on_batch(
                        [noise, x1_train, x2_train, x3_train], valid)
                    avg_g_cost += g_loss / (total_batches * (iter_gen + 1))

            # Validation loss per epoch is disabled in this variant;
            # train_k_folds shows the same computation with a validation generator.

            print("%d: [D loss %f, acc.: %.2f%%] [G loss: %f]" %
                  (epoch, avg_d_cost, 100 * avg_acc, avg_g_cost))
            #print("Epoch: %d, val loss: %f" % (epoch, val_loss))

            # Add scalars to the TensorBoard logs
            summary = tf.Summary(value=[
                tf.Summary.Value(tag="d_loss", simple_value=avg_d_cost),
            ])
            writer.add_summary(summary, global_step=epoch)
            summary2 = tf.Summary(value=[
                tf.Summary.Value(tag="acc", simple_value=avg_acc),
            ])
            writer.add_summary(summary2, global_step=epoch)
            summary3 = tf.Summary(value=[
                tf.Summary.Value(tag="g_loss", simple_value=avg_g_cost)
            ])
            writer.add_summary(summary3, global_step=epoch)
            # At the sampling interval, generated image samples could be written:
            # self.sample_images(epoch, x_train_real[0], x1_train_real[0], x2_train_real[0], x3_train_real[0], affine, fold_no)

            # On the final epoch, save predictions for the whole test set
            if epochs - epoch == 1:
                self.sample_all(x_test_real, x1_test_real, x2_test_real,
                                x3_test_real, affine, fold_no)

        batchGen.finish()