# Beispiel #1 (Example #1 — snippet marker from the original scrape,
# converted to a comment so the file parses)
# 0 (score count from the scrape, kept for provenance)
def _make_noisy_batch(data_iter):
    """Pull one batch from *data_iter* and build a (noisy, clean) pair.

    The torch NCHW batch is converted to NHWC numpy; channel 0 is kept as
    the clean single-channel image, and per-image Gaussian noise with a
    random sd drawn from NOISE_STD_RANGE is applied to produce the noisy
    copy.  Returns (features, original_images), both [N, H, W, 1] float.
    """
    features2, _labels = next(data_iter)
    features2 = features2.data.numpy().transpose(0, 2, 3, 1)
    features = np.zeros([len(features2), img_size[0], img_size[1], 1])
    original_images = np.zeros([len(features2), img_size[0], img_size[1], 1])
    for k in range(len(features)):
        features[k] = add_gaussian_noise(
            features2[k],
            sd=np.random.uniform(NOISE_STD_RANGE[0], NOISE_STD_RANGE[1]))
        original_images[k, :, :, 0] = features2[k, :, :, 0]
    return features, original_images


def transfrom_data(NUM_TEST_PER_EPOCH):
    """Build noisy train/val batches and pass them through two SARGAN stages.

    Stage 1 (checkpoint at trained_model_path2) denoises the freshly
    corrupted batches; stage 2 (trained_model_path3) re-processes stage 1's
    output in place.  The clean originals are kept untouched for reference.

    Returns (encoded_data, original_data, val_data, val_original): lists of
    [N, H, W, 1] arrays, NUM_ITERATION train entries and NUM_TEST_PER_EPOCH
    validation entries.
    """
    g2 = tf.Graph()
    g3 = tf.Graph()
    encoded_data = []
    original_data = []
    val_data = []
    val_original = []
    train_loader, val_loader = get_data(BATCH_SIZE, BATCH_SIZE)
    trainiter = iter(train_loader)
    testiter = iter(val_loader)

    with g2.as_default():
        with tf.Session() as sess2:
            # Build the model so its gen_img/image/cond tensors exist; the
            # saver comes from the imported meta graph.  (The original also
            # created a plain tf.train.Saver() here that was immediately
            # overwritten — dead code, removed.)
            sargan_model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver = tf.train.import_meta_graph(
                trained_model_path2 + '/sargan_mnist.meta')
            sargan_saver.restore(
                sess2, tf.train.latest_checkpoint(trained_model_path2))

            for _ in range(NUM_ITERATION):
                features, original_images = _make_noisy_batch(trainiter)
                # the noisy image is fed as both input and condition
                processed_batch = sess2.run(
                    sargan_model.gen_img,
                    feed_dict={sargan_model.image: features,
                               sargan_model.cond: features})
                encoded_data.append(processed_batch)
                original_data.append(original_images)

            for _ in range(NUM_TEST_PER_EPOCH):
                features, original_images = _make_noisy_batch(testiter)
                processed_batch = sess2.run(
                    sargan_model.gen_img,
                    feed_dict={sargan_model.image: features,
                               sargan_model.cond: features})
                val_original.append(original_images)
                val_data.append(processed_batch)

    with g3.as_default():
        with tf.Session() as sess3:
            sargan_model2 = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver2 = tf.train.import_meta_graph(
                trained_model_path3 + '/sargan_mnist.meta')
            sargan_saver2.restore(
                sess3, tf.train.latest_checkpoint(trained_model_path3))
            # second-stage pass overwrites the lists in place
            for i in range(NUM_ITERATION):
                encoded_data[i] = sess3.run(
                    sargan_model2.gen_img,
                    feed_dict={sargan_model2.image: encoded_data[i],
                               sargan_model2.cond: encoded_data[i]})
            for i in range(NUM_TEST_PER_EPOCH):
                val_data[i] = sess3.run(
                    sargan_model2.gen_img,
                    feed_dict={sargan_model2.image: val_data[i],
                               sargan_model2.cond: val_data[i]})

    return encoded_data, original_data, val_data, val_original
def transfrom_data(NUM_TEST_PER_EPOCH):
    """Build noisy single-channel batches from ResNet-preprocessed images
    and denoise them with the first SARGAN autoencoder.

    Channel 1 of the caffe-style output of
    tf.keras.applications.resnet.preprocess_input is extracted, its channel
    mean (116.779) is added back, and the result is rescaled to [0, 1]
    before per-image Gaussian noise is applied.

    Returns (encoded_data, original_data): denoised batches and their clean
    counterparts, as lists of [N, H, W, C] arrays.

    NOTE(review): the batches are built over NUM_ITERATION iterations but
    the autoencoder pass below only touches the first NUM_TEST_PER_EPOCH
    entries — confirm the two counts are meant to be equal.
    """
    encoded_data = []
    original_data = []
    train_loader = get_data(BATCH_SIZE)
    trainiter = iter(train_loader)
    for i in range(NUM_ITERATION):
        features2, labels = next(trainiter)
        # torch NCHW in [0, 1] -> NHWC, rescaled to [0, 255], then
        # caffe-style channel-mean subtraction via preprocess_input
        features2 = np.array(
            tf.keras.applications.resnet.preprocess_input(
                features2.data.numpy().transpose(0, 2, 3, 1) * 255))
        features = np.zeros(
            [len(features2), img_size[0], img_size[1], img_size[2]])
        features[:, :, :, 0] = features2[:, :, :, 1]
        # undo the mean subtraction for this channel and map to [0, 1]
        features = np.clip(features + 116.779, 0, 255) / 255
        original_images = np.copy(features)
        for k in range(len(features)):
            features[k] = add_gaussian_noise(
                features[k],
                sd=np.random.uniform(NOISE_STD_RANGE[0], NOISE_STD_RANGE[1]))
        encoded_data.append(np.copy(features))
        original_data.append(np.copy(original_images))

    gpu_options = tf.GPUOptions(allow_growth=True,
                                visible_device_list=str(GPU_ID))
    config = tf.ConfigProto(gpu_options=gpu_options)
    gx1 = tf.Graph()
    with gx1.as_default():
        with tf.Session(config=config) as sess1:
            sargan_model1 = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            # The plain tf.train.Saver() the original created here was
            # immediately overwritten by import_meta_graph — dead code,
            # removed.
            sargan_saver1 = tf.train.import_meta_graph(
                trained_model_path1 + '/sargan_mnist.meta')
            sargan_saver1.restore(
                sess1, tf.train.latest_checkpoint(trained_model_path1))
            for ibatch in range(NUM_TEST_PER_EPOCH):
                processed_batch = sess1.run(
                    sargan_model1.gen_img,
                    feed_dict={sargan_model1.image: encoded_data[ibatch],
                               sargan_model1.cond: encoded_data[ibatch]})
                encoded_data[ibatch] = np.copy(processed_batch)

    return encoded_data, original_data
# Beispiel #3 (Example #3 — snippet marker from the original scrape,
# converted to a comment so the file parses)
# 0 (score count from the scrape, kept for provenance)
def transfrom_data(NUM_TEST_PER_EPOCH):
    """Build (noisy, clean) train and validation batches — no model pass.

    Each torch NCHW batch is converted to NHWC numpy; channel 0 is kept as
    the clean image and a noise-corrupted copy is produced with a random sd
    from NOISE_STD_RANGE.

    Returns (encoded_data, original_data, val_data, val_original): lists of
    [N, H, W, 1] arrays — NUM_ITERATION train pairs and NUM_TEST_PER_EPOCH
    validation pairs.

    Note: the original wrapped these loops in an unused tf.Graph/tf.Session
    pair (no sess.run calls); that dead scaffolding was removed.
    """
    encoded_data = []
    original_data = []
    val_data = []
    val_original = []
    train_loader, val_loader = get_data(BATCH_SIZE, BATCH_SIZE)
    trainiter = iter(train_loader)
    testiter = iter(val_loader)

    for i in range(NUM_ITERATION):
        features2, labels = next(trainiter)
        # torch NCHW -> NHWC numpy
        features2 = features2.data.numpy().transpose(0, 2, 3, 1)
        features = np.zeros(
            [len(features2), img_size[0], img_size[1], 1])
        original_images = np.zeros(
            [len(features2), img_size[0], img_size[1], 1])
        for k in range(len(features)):
            features[k] = add_gaussian_noise(
                features2[k],
                sd=np.random.uniform(NOISE_STD_RANGE[0], NOISE_STD_RANGE[1]))
            original_images[k, :, :, 0] = features2[k, :, :, 0]
        encoded_data.append(features)
        original_data.append(original_images)

    for i in range(NUM_TEST_PER_EPOCH):
        features2, labels = next(testiter)
        features2 = features2.data.numpy().transpose(0, 2, 3, 1)
        features = np.zeros(
            [len(features2), img_size[0], img_size[1], 1])
        original_images = np.zeros(
            [len(features2), img_size[0], img_size[1], 1])
        for k in range(len(features)):
            features[k] = add_gaussian_noise(
                features2[k],
                sd=np.random.uniform(NOISE_STD_RANGE[0], NOISE_STD_RANGE[1]))
            original_images[k, :, :, 0] = features2[k, :, :, 0]
        val_data.append(features)
        val_original.append(original_images)

    return encoded_data, original_data, val_data, val_original
def blur(batch):
    """Corrupt *batch* for the sargan model: per-image Gaussian noise
    (random sd from NOISE_STD_RANGE) followed by a sigma-2 Gaussian blur.

    *batch* holds flattened images; they are reshaped to (H, W, C) first.
    Returns a new [N, H, W, C] array; the input is left untouched.
    """
    count = len(batch)
    # restore the flat images to square (H, W, C) form
    squared = batch.reshape([count, img_size[0], img_size[1], img_size[2]])
    corrupted = np.zeros([count, img_size[0], img_size[1], img_size[2]])
    for idx, image in enumerate(squared):
        noisy = add_gaussian_noise(
            image,
            sd=np.random.uniform(NOISE_STD_RANGE[0], NOISE_STD_RANGE[1]))
        # the extra leading axis mirrors the original call shape so the
        # filter sees an identical 4-D input
        corrupted[idx] = gaussian_filter(np.array([noisy]), sigma=2)[0]
    return corrupted
def get_processed_data(x_batch, img_size, image_index):
    """Split *x_batch* per channel into 64-image sub-batches, run each
    channel through its processing chain (get_channel1/2/3), and reassemble
    the full batch.

    Parameters: x_batch — NHWC image array, 3 channels assumed; img_size —
    (H, W, C); image_index — forwarded to the get_channelN helpers
    (semantics defined there — presumably selects the preview image;
    TODO confirm).
    Returns (next_batch, im1, im2, im3, im4): the reassembled
    [N, H, W, 3] batch plus four channel-stacked preview images.
    """
    NOISE_STD_RANGE = [0.0, 0.02]
    # add fixed-sd noise, then append a trailing singleton axis, giving a
    # 5-D [N, H, W, 3, 1] array; the per-channel slices below are therefore
    # [*, H, W, 1]
    x_batch = np.expand_dims(
        add_gaussian_noise(x_batch, sd=NOISE_STD_RANGE[1]), 4)
    channel1 = []
    channel2 = []
    channel3 = []
    # number of full 64-image sub-batches, plus the leftover count
    num_batches = math.floor(len(x_batch) / 64)
    differnce = len(x_batch) - num_batches * 64
    num_batches2 = num_batches
    # zero-initialized pad buffer for the final partial sub-batch; rows
    # beyond `differnce` are never written and stay zero
    batch1 = np.zeros([64, img_size[0], img_size[1], img_size[2]])
    for i in range(num_batches):
        channel1.append(np.copy(x_batch[(64 * i):(64 * i + 64), :, :, 0]))
        channel2.append(np.copy(x_batch[(64 * i):(64 * i + 64), :, :, 1]))
        channel3.append(np.copy(x_batch[(64 * i):(64 * i + 64), :, :, 2]))

    if (differnce != 0):
        # NOTE(review): the right-hand side is a [differnce, H, W, 1] slice
        # assigned into a [differnce, H, W, C] target — numpy broadcasts it
        # across the last axis; presumably intentional, verify
        batch1[0:differnce, :, :, :] = x_batch[(64 * num_batches):(
            64 * num_batches + differnce), :, :, 0]
        channel1.append(np.copy(batch1))
        batch1[0:differnce, :, :, :] = x_batch[(64 * num_batches):(
            64 * num_batches + differnce), :, :, 1]
        channel2.append(np.copy(batch1))
        batch1[0:differnce, :, :, :] = x_batch[(64 * num_batches):(
            64 * num_batches + differnce), :, :, 2]
        channel3.append(np.copy(batch1))
        num_batches2 += 1
    # each helper returns the processed sub-batches plus four per-stage
    # preview images for its channel
    channel1, im11, im21, im31, im41 = get_channel1(channel1, img_size,
                                                    num_batches2, image_index)
    channel2, im12, im22, im32, im42 = get_channel2(channel2, img_size,
                                                    num_batches2, image_index)
    channel3, im13, im23, im33, im43 = get_channel3(channel3, img_size,
                                                    num_batches2, image_index)

    # stack the previews channel-wise (note the 3-2-1 channel order)
    im1 = np.stack((im13, im12, im11), axis=2)
    im2 = np.stack((im23, im22, im21), axis=2)
    im3 = np.stack((im33, im32, im31), axis=2)
    im4 = np.stack((im43, im42, im41), axis=2)
    next_batch = np.zeros([len(x_batch), img_size[0], img_size[1], 3])

    # reassemble the full batch from the processed sub-batches
    for i in range(num_batches):
        next_batch[(64 * i):(64 * i + 64), :, :,
                   0] = np.copy(channel1[i][:, :, :, 0])
        next_batch[(64 * i):(64 * i + 64), :, :,
                   1] = np.copy(channel2[i][:, :, :, 0])
        next_batch[(64 * i):(64 * i + 64), :, :,
                   2] = np.copy(channel3[i][:, :, :, 0])
    if (differnce != 0):
        # only the first `differnce` rows of the padded sub-batch are real
        next_batch[(64 * num_batches):(64 * num_batches + differnce), :, :,
                   0] = np.copy(channel1[num_batches][0:differnce, :, :, 0])
        next_batch[(64 * num_batches):(64 * num_batches + differnce), :, :,
                   1] = np.copy(channel2[num_batches][0:differnce, :, :, 0])
        next_batch[(64 * num_batches):(64 * num_batches + differnce), :, :,
                   2] = np.copy(channel3[num_batches][0:differnce, :, :, 0])
    return next_batch, im1, im2, im3, im4
def transfrom_data(NUM_TEST_PER_EPOCH):
    """Collect NUM_ITERATION pairs of (noisy, clean) single-channel batches.

    Channel 0 of the caffe-style tf.keras.applications.resnet
    preprocess_input output is shifted back by its channel mean (103.939)
    and scaled to [0, 1]; Gaussian noise with a random sd from
    NOISE_STD_RANGE is then applied per image.

    Returns (encoded_data, original_data).  NUM_TEST_PER_EPOCH is accepted
    for interface parity with the sibling variants but is unused here.
    """
    encoded_data = []
    original_data = []
    batch_iter = iter(get_data(BATCH_SIZE))
    for _ in range(NUM_ITERATION):
        raw_batch, _labels = next(batch_iter)
        # torch NCHW in [0, 1] -> NHWC in [0, 255] -> mean-subtracted
        preprocessed = np.array(
            tf.keras.applications.resnet.preprocess_input(
                raw_batch.data.numpy().transpose(0, 2, 3, 1) * 255))
        clean = np.zeros(
            [len(preprocessed), img_size[0], img_size[1], img_size[2]])
        clean[:, :, :, 0] = preprocessed[:, :, :, 0]
        # restore the channel mean and normalize to [0, 1]
        clean = np.clip(clean + 103.939, 0, 255) / 255
        noisy = np.copy(clean)
        for idx in range(len(noisy)):
            noisy[idx] = add_gaussian_noise(
                noisy[idx],
                sd=np.random.uniform(NOISE_STD_RANGE[0], NOISE_STD_RANGE[1]))
        encoded_data.append(np.copy(noisy))
        original_data.append(np.copy(clean))

    return encoded_data, original_data
def evaluate_checkpoint(filename):
    """Run the attack/denoise pipeline for one classifier checkpoint.

    Restores the classifier from *filename*, builds adversarial and
    noise-corrupted versions of each evaluation batch, then feeds both
    through a chain of four SARGAN autoencoders (one tf.Graph/tf.Session
    each, restored from trained_model_path .. trained_model_path4).

    Side effects: saves JPEG snapshots of the first image of the first
    batch at every stage ("MNIST_NATURALX0", "MNIST_ADVX0".."MNIST_ADVX4",
    "MNIST_BLURX0".."MNIST_BLURX4").  Returns None.

    Depends on module-level globals: saver, attack, get_data, cycle,
    BATCH_SIZE, img_size, NOISE_STD_RANGE, num_eval_examples,
    eval_batch_size, and the trained_model_path* constants.
    """
    #sys.stdout = open(os.devnull, 'w')
    #Different graphs for all the models
    gx1 = tf.Graph()
    gx2 = tf.Graph()
    gx3 = tf.Graph()
    gx4 = tf.Graph()
    with tf.Session() as sess:
        # Restore the checkpoint
        saver.restore(sess, filename)

        # Iterate over the samples batch-by-batch
        #number of batches
        num_batches = int(math.ceil(num_eval_examples / eval_batch_size))

        x_blur_list4 = []
        x_adv_list4 = []
        #Storing y values

        train_loader = get_data(BATCH_SIZE)
        trainiter = iter(cycle(train_loader))
        for ibatch in range(num_batches):

            x_batch2, y_batch = next(trainiter)
            # torch NCHW -> NHWC numpy
            x_batch2 = np.array(x_batch2.data.numpy().transpose(0, 2, 3, 1))
            x_batch = np.zeros([len(x_batch2), img_size[0] * img_size[1]])

            # snapshot of the first natural image (re-saved every batch)
            x_batch3 = np.zeros([img_size[0], img_size[1]])
            x_batch3[:, :] = x_batch2[0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_NATURALX0", "JPEG")

            # flatten to (N, H*W) for the attack
            for i in range(len(x_batch2)):
                x_batch[i] = x_batch2[i].reshape([img_size[0] * img_size[1]])
            x_batch_adv = attack.perturb(x_batch, y_batch, sess)

            x_batch2 = np.zeros(
                [len(x_batch), img_size[0], img_size[1], img_size[2]])
            x_batch_adv2 = np.zeros(
                [len(x_batch), img_size[0], img_size[1], img_size[2]])
            # NOTE(review): uniform over a degenerate range — effectively a
            # fixed sd of NOISE_STD_RANGE[1] for both copies
            for k in range(len(x_batch)):
                x_batch2[k] = add_gaussian_noise(
                    x_batch[k].reshape([img_size[0], img_size[1],
                                        img_size[2]]),
                    sd=np.random.uniform(NOISE_STD_RANGE[1],
                                         NOISE_STD_RANGE[1]))
                x_batch_adv2[k] = add_gaussian_noise(x_batch_adv[k].reshape(
                    [img_size[0], img_size[1], img_size[2]]),
                                                     sd=np.random.uniform(
                                                         NOISE_STD_RANGE[1],
                                                         NOISE_STD_RANGE[1]))

            x_blur_list4.append(x_batch2)
            x_adv_list4.append(x_batch_adv2)

            # NOTE(review): index [0] — these snapshots always show the
            # first batch even though this block runs every iteration
            x_batch3 = np.zeros([img_size[0], img_size[1]])
            x_batch3[:, :] = x_blur_list4[0][0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_BLURX0", "JPEG")
            x_batch3 = np.zeros([img_size[0], img_size[1]])
            x_batch3[:, :] = x_adv_list4[0][0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_ADVX0", "JPEG")

    #Running through first autoencoder
    with gx1.as_default():
        with tf.Session() as sess2:
            sargan_model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            # NOTE(review): this Saver is immediately replaced by
            # import_meta_graph below (dead assignment, kept as-is)
            sargan_saver = tf.train.Saver()
            sargan_saver = tf.train.import_meta_graph(trained_model_path +
                                                      '/sargan_mnist.meta')
            sargan_saver.restore(
                sess2, tf.train.latest_checkpoint(trained_model_path))
            # denoise both lists in place, batch by batch
            for ibatch in range(num_batches):
                processed_batch = sess2.run(sargan_model.gen_img,
                                            feed_dict={
                                                sargan_model.image:
                                                x_adv_list4[ibatch],
                                                sargan_model.cond:
                                                x_adv_list4[ibatch]
                                            })
                x_adv_list4[ibatch] = processed_batch

                blurred_batch = sess2.run(sargan_model.gen_img,
                                          feed_dict={
                                              sargan_model.image:
                                              x_blur_list4[ibatch],
                                              sargan_model.cond:
                                              x_blur_list4[ibatch]
                                          })
                x_blur_list4[ibatch] = blurred_batch

            # NOTE(review): unlike the other save sites, x_batch3 is not
            # re-zeroed before this write; the full-slice assignment
            # overwrites every element, so the output is unaffected
            x_batch3[:, :] = x_blur_list4[0][0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_BLURX1", "JPEG")
            x_batch3 = np.zeros([img_size[0], img_size[1]])
            x_batch3[:, :] = x_adv_list4[0][0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_ADVX1", "JPEG")

    # second autoencoder stage
    with gx2.as_default():
        with tf.Session() as sessx2:
            sargan_model2 = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver2 = tf.train.Saver()
            sargan_saver2 = tf.train.import_meta_graph(trained_model_path2 +
                                                       '/sargan_mnist.meta')
            sargan_saver2.restore(
                sessx2, tf.train.latest_checkpoint(trained_model_path2))
            for ibatch in range(num_batches):
                processed_batch = sessx2.run(sargan_model2.gen_img,
                                             feed_dict={
                                                 sargan_model2.image:
                                                 x_adv_list4[ibatch],
                                                 sargan_model2.cond:
                                                 x_adv_list4[ibatch]
                                             })
                x_adv_list4[ibatch] = (processed_batch)

                blurred_batch = sessx2.run(sargan_model2.gen_img,
                                           feed_dict={
                                               sargan_model2.image:
                                               x_blur_list4[ibatch],
                                               sargan_model2.cond:
                                               x_blur_list4[ibatch]
                                           })
                x_blur_list4[ibatch] = (blurred_batch)
            x_batch3 = np.zeros([img_size[0], img_size[1]])
            x_batch3[:, :] = x_blur_list4[0][0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_BLURX2", "JPEG")
            x_batch3 = np.zeros([img_size[0], img_size[1]])
            x_batch3[:, :] = x_adv_list4[0][0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_ADVX2", "JPEG")

    # third autoencoder stage
    with gx3.as_default():
        with tf.Session() as sessx3:
            sargan_model3 = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver3 = tf.train.Saver()
            sargan_saver3 = tf.train.import_meta_graph(trained_model_path3 +
                                                       '/sargan_mnist.meta')
            sargan_saver3.restore(
                sessx3, tf.train.latest_checkpoint(trained_model_path3))
            for ibatch in range(num_batches):
                processed_batch = sessx3.run(sargan_model3.gen_img,
                                             feed_dict={
                                                 sargan_model3.image:
                                                 x_adv_list4[ibatch],
                                                 sargan_model3.cond:
                                                 x_adv_list4[ibatch]
                                             })
                x_adv_list4[ibatch] = processed_batch

                blurred_batch = sessx3.run(sargan_model3.gen_img,
                                           feed_dict={
                                               sargan_model3.image:
                                               x_blur_list4[ibatch],
                                               sargan_model3.cond:
                                               x_blur_list4[ibatch]
                                           })
                x_blur_list4[ibatch] = blurred_batch

            x_batch3 = np.zeros([img_size[0], img_size[1]])
            x_batch3[:, :] = x_blur_list4[0][0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_BLURX3", "JPEG")
            x_batch3 = np.zeros([img_size[0], img_size[1]])
            x_batch3[:, :] = x_adv_list4[0][0, :, :, 0] * 255
            nextimage = Image.fromarray(x_batch3.astype(np.uint8))
            nextimage.save("MNIST_ADVX3", "JPEG")

    #Final autoencoder setup
    with gx4.as_default():
        with tf.Session() as sessx4:
            sargan_model4 = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver4 = tf.train.Saver()
            sargan_saver4 = tf.train.import_meta_graph(trained_model_path4 +
                                                       '/sargan_mnist.meta')
            sargan_saver4.restore(
                sessx4, tf.train.latest_checkpoint(trained_model_path4))
            for ibatch in range(num_batches):
                processed_batch = sessx4.run(sargan_model4.gen_img,
                                             feed_dict={
                                                 sargan_model4.image:
                                                 x_adv_list4[ibatch],
                                                 sargan_model4.cond:
                                                 x_adv_list4[ibatch]
                                             })
                # final stage flattens back to (N, H*W) for the classifier
                x_adv_list4[ibatch] = processed_batch.reshape(
                    [len(x_batch), img_size[0] * img_size[1]])

                x_batch3 = np.zeros([img_size[0], img_size[1]])
                x_batch3[:, :] = processed_batch[0, :, :, 0] * 255
                nextimage = Image.fromarray(x_batch3.astype(np.uint8))
                nextimage.save("MNIST_ADVX4", "JPEG")

                blurred_batch = sessx4.run(sargan_model4.gen_img,
                                           feed_dict={
                                               sargan_model4.image:
                                               x_blur_list4[ibatch],
                                               sargan_model4.cond:
                                               x_blur_list4[ibatch]
                                           })
                x_blur_list4[ibatch] = blurred_batch.reshape(
                    [len(x_batch), img_size[0] * img_size[1]])

                x_batch3 = np.zeros([img_size[0], img_size[1]])
                x_batch3[:, :] = blurred_batch[0, :, :, 0] * 255
                nextimage = Image.fromarray(x_batch3.astype(np.uint8))
                nextimage.save("MNIST_BLURX4", "JPEG")
# Beispiel #8 (Example #8 — snippet marker from the original scrape,
# converted to a comment so the file parses)
# 0 (score count from the scrape, kept for provenance)
def main(args):
    """Train the SARGAN denoiser; periodically save checkpoints and sample
    Original/Corrupted/Recovered figures.

    *args* is accepted for the script entry-point interface but unused.
    Relies on module-level globals: SARGAN, get_data_loader,
    add_gaussian_noise, gaussian_filter, img_size, BATCH_SIZE,
    LEARNING_RATE, MAX_EPOCH, NUM_ITERATION, SAVE_EVERY_EPOCH,
    NOISE_STD_RANGE, GPU_ID, trained_model_path, output_path, ski_me.

    Bug fixed from the original: the per-epoch "testing" loop drew its
    batch from the *training* iterator while `testiter` was created and
    never used — it now evaluates on the validation iterator.
    """
    image_number = 0
    model = SARGAN(img_size, BATCH_SIZE, img_channel=img_size[2])
    with tf.variable_scope("d_opt", reuse=tf.AUTO_REUSE):
        d_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.d_loss, var_list=model.d_vars)
    with tf.variable_scope("g_opt", reuse=tf.AUTO_REUSE):
        g_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver(max_to_keep=20)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                visible_device_list=str(GPU_ID))
    config = tf.ConfigProto(gpu_options=gpu_options)

    progress_bar = tqdm(range(MAX_EPOCH), unit="epoch")
    # one loss value per training iteration
    train_d_loss_values = []
    train_g_loss_values = []

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in progress_bar:
            train_loader, val_loader = get_data_loader(BATCH_SIZE, BATCH_SIZE)
            counter = 0
            epoch_start_time = time.time()
            trainiter = iter(train_loader)
            for i in range(NUM_ITERATION):
                features2, labels = next(trainiter)
                # torch NCHW in [0, 1] -> NHWC scaled to [0, 255]
                features2 = features2.data.numpy().transpose(0, 2, 3, 1) * 255

                features = np.zeros(
                    [len(features2), img_size[0], img_size[1], 1])
                # grayscale-convert each image back into [0, 1]; a separate
                # index avoids shadowing the outer loop counter `i`
                for idx in range(len(features2)):
                    nextimage = Image.fromarray(
                        (features2[idx]).astype(np.uint8))
                    nextimage = nextimage.convert('L')
                    features[idx, :, :,
                             0] = np.array(nextimage, dtype='float32') / 255
                # copy the batch, then corrupt it; the uniform draw spans a
                # degenerate range, i.e. a fixed sd of NOISE_STD_RANGE[1]
                features_copy = features.copy()
                corrupted_batch = np.array([
                    add_gaussian_noise(image,
                                       sd=np.random.uniform(
                                           NOISE_STD_RANGE[1],
                                           NOISE_STD_RANGE[1]))
                    for image in features_copy
                ])
                for idx in range(len(corrupted_batch)):
                    corrupted_batch[idx] = gaussian_filter(
                        corrupted_batch[idx], sigma=1)
                # one discriminator step, then one generator step
                _, m = sess.run([d_opt, model.d_loss],
                                feed_dict={
                                    model.image: features,
                                    model.cond: corrupted_batch
                                })
                _, M = sess.run([g_opt, model.g_loss],
                                feed_dict={
                                    model.image: features,
                                    model.cond: corrupted_batch
                                })
                train_d_loss_values.append(m)
                train_g_loss_values.append(M)
                # progress notification every 25 iterations
                counter += 1
                if counter % 25 == 0:
                    print("\rEpoch [%d], Iteration [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, counter, time.time() - epoch_start_time, m, M))

            # save the trained network
            if epoch % SAVE_EVERY_EPOCH == 0:
                save_path = saver.save(sess,
                                       (trained_model_path + "/sargan_mnist"))
                print("\n\nModel saved in file: %s\n\n" % save_path)
            # (leftover from an earlier checkpoint naming scheme:
            #  "%s_model_%s.ckpt" % (experiment_name, epoch+1))

            ##### TESTING FOR CURRENT EPOCH
            testiter = iter(val_loader)
            NUM_TEST_PER_EPOCH = 1

            sum_psnr = 0
            list_images = []
            for j in range(NUM_TEST_PER_EPOCH):
                # BUGFIX: draw from the validation iterator (the original
                # used trainiter here, leaving testiter unused)
                features2, labels = next(testiter)
                features2 = features2.data.numpy().transpose(0, 2, 3, 1) * 255

                features = np.zeros(
                    [len(features2), img_size[0], img_size[1], 1])
                for i in range(len(features2)):
                    nextimage = Image.fromarray(
                        (features2[i]).astype(np.uint8))
                    nextimage = nextimage.convert('L')
                    features[i, :, :,
                             0] = np.array(nextimage, dtype='float32') / 255
                batch_copy = features.copy()
                # corrupt the images (random sd this time)
                corrupted_batch = np.array([
                    add_gaussian_noise(image,
                                       sd=np.random.uniform(
                                           NOISE_STD_RANGE[0],
                                           NOISE_STD_RANGE[1]))
                    for image in batch_copy
                ])
                for i in range(len(corrupted_batch)):
                    corrupted_batch[i] = gaussian_filter(corrupted_batch[i],
                                                         sigma=1)

                gen_imgs = sess.run(model.gen_img,
                                    feed_dict={
                                        model.image: features,
                                        model.cond: corrupted_batch
                                    })
                print(features.shape, gen_imgs.shape)
                # keep three (original, corrupted, recovered) triples;
                # assumes the batch holds at least 35 images — TODO confirm
                # BATCH_SIZE
                list_images.append(
                    (features[0], corrupted_batch[0], gen_imgs[0]))
                list_images.append(
                    (features[17], corrupted_batch[17], gen_imgs[17]))
                list_images.append(
                    (features[34], corrupted_batch[34], gen_imgs[34]))
                for i in range(len(gen_imgs)):
                    current_img = features[i]
                    recovered_img = gen_imgs[i]
                    sum_psnr += ski_me.compare_psnr(current_img, recovered_img,
                                                    1)
            # NOTE(review): divisor 50 is hard-coded (sample count?) and
            # average_psnr is never used afterwards — confirm before relying
            # on it
            average_psnr = sum_psnr / 50

            epoch_running_time = time.time() - epoch_start_time
            ############### PLOT SAMPLE FIGURES ##############
            rows = 1
            cols = 3
            display_mean = np.array([0.485, 0.456, 0.406])
            display_std = np.array([0.229, 0.224, 0.225])
            if epoch % SAVE_EVERY_EPOCH == 0:
                # NOTE(review): each imgs_N is a *tuple* of three (H, W, 1)
                # arrays; broadcasting the ImageNet mean/std over it yields a
                # (3, H, W, 3) array — looks unintended for grayscale data,
                # but preserved as-is
                imgs_1 = list_images[0]
                imgs_2 = list_images[1]
                imgs_3 = list_images[2]
                imgs_1 = display_std * imgs_1 + display_mean
                imgs_2 = display_std * imgs_2 + display_mean
                imgs_3 = display_std * imgs_3 + display_mean
                fig = plt.figure(figsize=(14, 4))
                ax = fig.add_subplot(rows, cols, 1)
                ax.imshow(imgs_1[0])
                ax.set_title("Original", color='grey')
                ax = fig.add_subplot(rows, cols, 2)
                ax.imshow(imgs_1[1])
                ax.set_title("Corrupted", color='grey')
                ax = fig.add_subplot(rows, cols, 3)
                ax.imshow(imgs_1[2])
                ax.set_title("Recovered", color='grey')
                plt.tight_layout()
                sample_test_file_1 = os.path.join(
                    output_path, 'image_%d_1.jpg' % image_number)
                plt.savefig(sample_test_file_1, dpi=300)

                fig = plt.figure(figsize=(14, 4))
                ax = fig.add_subplot(rows, cols, 1)
                ax.imshow(imgs_2[0])
                ax.set_title("Original", color='grey')
                ax = fig.add_subplot(rows, cols, 2)
                ax.imshow(imgs_2[1])
                ax.set_title("Corrupted", color='grey')
                ax = fig.add_subplot(rows, cols, 3)
                ax.imshow(imgs_2[2])
                ax.set_title("Recovered", color='grey')
                plt.tight_layout()
                sample_test_file_2 = os.path.join(
                    output_path, 'image_%d_2.jpg' % image_number)
                plt.savefig(sample_test_file_2, dpi=300)

                fig = plt.figure(figsize=(14, 4))
                ax = fig.add_subplot(rows, cols, 1)
                ax.imshow(imgs_3[0])
                ax.set_title("Original", color='grey')
                ax = fig.add_subplot(rows, cols, 2)
                ax.imshow(imgs_3[1])
                ax.set_title("Corrupted", color='grey')
                ax = fig.add_subplot(rows, cols, 3)
                ax.imshow(imgs_3[2])
                ax.set_title("Recovered", color='grey')
                plt.tight_layout()
                sample_test_file_3 = os.path.join(
                    output_path, 'image_%d_3.jpg' % image_number)
                image_number += 1
                plt.savefig(sample_test_file_3, dpi=300)
            plt.close("all")
def evaluate_checkpoint(filename):
    """Evaluate a classifier checkpoint with and without autoencoder defenses.

    Pipeline (TF1 style, one tf.Graph/tf.Session per model):
      1. default graph: restore the classifier from *filename*, build
         adversarial batches (``attack.perturb``) and noise-corrupted copies;
      2. gx1..gx4: run four SARGAN autoencoders in sequence, each denoising
         the previous stage's output, keeping per-stage copies of the images;
      3. g3: restore the classifier again and report accuracy/cross-entropy
         for natural, corrupted, and each of the four autoencoder stages.

    Relies on module-level globals: saver, attack, num_eval_examples,
    eval_batch_size, BATCH_SIZE, img_size, NOISE_STD_RANGE,
    trained_model_path..trained_model_path4, Model, SARGAN, get_data, cycle.

    NOTE(review): this file defines evaluate_checkpoint twice; this earlier
    definition is shadowed by the later one at import time — confirm which
    is intended.
    """
    #sys.stdout = open(os.devnull, 'w')
    # Different graphs for all the models (one per autoencoder + classifier)
    gx1 = tf.Graph()
    gx2 = tf.Graph()
    gx3 = tf.Graph()
    gx4 = tf.Graph()
    g3 = tf.Graph()
    # This session runs on the DEFAULT graph, where the global `saver`,
    # `attack` and classifier model were built — not on gx1.
    with tf.Session() as sess:
        # Restore the checkpoint
        saver.restore(sess, filename)

        # Iterate over the samples batch-by-batch
        # number of batches
        # NOTE(review): the trailing `- 1` drops the final batch, so slightly
        # fewer than num_eval_examples images are evaluated — confirm intent.
        num_batches = int(math.ceil(num_eval_examples / eval_batch_size)) - 1
        total_xent_nat = 0.
        total_xent_corr = 0.
        total_corr_nat = 0.
        total_corr_corr = 0.

        # Per-stage accumulators, one slot per autoencoder (4 stages).
        total_corr_adv = np.zeros([4]).astype(dtype='float32')
        total_corr_blur = np.zeros([4]).astype(dtype='float32')
        total_xent_adv = np.zeros([4]).astype(dtype='float32')
        total_xent_blur = np.zeros([4]).astype(dtype='float32')

        # storing the various images (per-batch lists, indexed by ibatch)
        x_batch_list = []
        x_corr_list = []
        x_blur_list1 = []
        x_adv_list1 = []
        x_blur_list2 = []
        x_adv_list2 = []
        x_blur_list3 = []
        x_adv_list3 = []
        x_blur_list4 = []
        x_adv_list4 = []
        # Storing y values
        y_batch_list = []

        train_loader = get_data(BATCH_SIZE)
        trainiter = iter(cycle(train_loader))
        for ibatch in range(num_batches):

            x_batch2, y_batch = next(trainiter)
            y_batch_list.append(y_batch)
            # NCHW -> NHWC, then flatten each image for the classifier/attack.
            x_batch2 = np.array(x_batch2.data.numpy().transpose(0, 2, 3, 1))
            x_batch = np.zeros([len(x_batch2), img_size[0] * img_size[1]])
            for i in range(len(x_batch2)):
                x_batch[i] = x_batch2[i].reshape([img_size[0] * img_size[1]])
            # Adversarial perturbation of the flattened batch.
            x_batch_adv = attack.perturb(x_batch, y_batch, sess)

            # Rebuild image-shaped, noise-corrupted copies of both the clean
            # and the adversarial batches (autoencoder input format).
            x_batch2 = np.zeros(
                [len(x_batch), img_size[0], img_size[1], img_size[2]])
            x_batch_adv2 = np.zeros(
                [len(x_batch), img_size[0], img_size[1], img_size[2]])
            for k in range(len(x_batch)):
                # NOTE(review): uniform(NOISE_STD_RANGE[1], NOISE_STD_RANGE[1])
                # always returns NOISE_STD_RANGE[1]; the lower bound was
                # presumably meant to be NOISE_STD_RANGE[0] — confirm.
                x_batch2[k] = add_gaussian_noise(
                    x_batch[k].reshape([img_size[0], img_size[1],
                                        img_size[2]]),
                    sd=np.random.uniform(NOISE_STD_RANGE[1],
                                         NOISE_STD_RANGE[1]))
                x_batch_adv2[k] = add_gaussian_noise(x_batch_adv[k].reshape(
                    [img_size[0], img_size[1], img_size[2]]),
                                                     sd=np.random.uniform(
                                                         NOISE_STD_RANGE[1],
                                                         NOISE_STD_RANGE[1]))

            x_batch_list.append(x_batch)
            x_corr_list.append(x_batch_adv)
            # list4 holds the "current stage" images and is overwritten by
            # each autoencoder pass below.
            x_blur_list4.append(x_batch2)
            x_adv_list4.append(x_batch_adv2)

    # Running through first autoencoder
    with gx1.as_default():
        with tf.Session() as sess2:
            sargan_model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver = tf.train.Saver()
            sargan_saver = tf.train.import_meta_graph(trained_model_path +
                                                      '/sargan_mnist.meta')
            sargan_saver.restore(
                sess2, tf.train.latest_checkpoint(trained_model_path))
            for ibatch in range(num_batches):
                processed_batch = sess2.run(sargan_model.gen_img,
                                            feed_dict={
                                                sargan_model.image:
                                                x_adv_list4[ibatch],
                                                sargan_model.cond:
                                                x_adv_list4[ibatch]
                                            })
                x_adv_list4[ibatch] = processed_batch

                blurred_batch = sess2.run(sargan_model.gen_img,
                                          feed_dict={
                                              sargan_model.image:
                                              x_blur_list4[ibatch],
                                              sargan_model.cond:
                                              x_blur_list4[ibatch]
                                          })
                x_blur_list4[ibatch] = blurred_batch

                # adding images to first autoencoder data set
                # NOTE(review): len(x_batch) relies on the loop variable
                # leaked from the first session's batch loop (the last
                # batch's size) — works only because all batches are equal.
                x_blur_list1.append(
                    blurred_batch.reshape(
                        [len(x_batch), img_size[0] * img_size[1]]))
                x_adv_list1.append(
                    processed_batch.reshape(
                        [len(x_batch), img_size[0] * img_size[1]]))
            # PSNR of stage-1 output vs clean images, averaged per image then
            # per batch.  NOTE(review): batch size 64 is hard-coded here.
            psnr = 0
            for jj in range(num_batches):
                next_psnr = 0
                psnr_value = (tf.image.psnr(np.array(x_batch_list[jj]).reshape(
                    [64, img_size[0], img_size[1], img_size[2]]),
                                            np.array(x_adv_list4[jj]).reshape([
                                                64, img_size[0], img_size[1],
                                                img_size[2]
                                            ]),
                                            max_val=1))
                psnr_value = sess2.run(psnr_value)
                for i in range(64):
                    next_psnr += psnr_value[i]
                psnr += next_psnr / 64
            psnr /= num_batches
    print("Not actual PSNR= ", psnr, "\n\n\n")
    # Second autoencoder: consumes stage-1 output, produces stage-2 images.
    with gx2.as_default():
        with tf.Session() as sessx2:
            sargan_model2 = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver2 = tf.train.Saver()
            sargan_saver2 = tf.train.import_meta_graph(trained_model_path2 +
                                                       '/sargan_mnist.meta')
            sargan_saver2.restore(
                sessx2, tf.train.latest_checkpoint(trained_model_path2))
            for ibatch in range(num_batches):
                processed_batch = sessx2.run(sargan_model2.gen_img,
                                             feed_dict={
                                                 sargan_model2.image:
                                                 x_adv_list4[ibatch],
                                                 sargan_model2.cond:
                                                 x_adv_list4[ibatch]
                                             })
                x_adv_list4[ibatch] = (processed_batch)

                blurred_batch = sessx2.run(sargan_model2.gen_img,
                                           feed_dict={
                                               sargan_model2.image:
                                               x_blur_list4[ibatch],
                                               sargan_model2.cond:
                                               x_blur_list4[ibatch]
                                           })
                x_blur_list4[ibatch] = (blurred_batch)

                # adding images to second autoencoder data set
                x_blur_list2.append(
                    blurred_batch.reshape(
                        [len(x_batch), img_size[0] * img_size[1]]))
                x_adv_list2.append(
                    processed_batch.reshape(
                        [len(x_batch), img_size[0] * img_size[1]]))
    # Third autoencoder: consumes stage-2 output, produces stage-3 images.
    with gx3.as_default():
        with tf.Session() as sessx3:
            sargan_model3 = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver3 = tf.train.Saver()
            sargan_saver3 = tf.train.import_meta_graph(trained_model_path3 +
                                                       '/sargan_mnist.meta')
            sargan_saver3.restore(
                sessx3, tf.train.latest_checkpoint(trained_model_path3))
            for ibatch in range(num_batches):
                processed_batch = sessx3.run(sargan_model3.gen_img,
                                             feed_dict={
                                                 sargan_model3.image:
                                                 x_adv_list4[ibatch],
                                                 sargan_model3.cond:
                                                 x_adv_list4[ibatch]
                                             })
                x_adv_list4[ibatch] = processed_batch

                blurred_batch = sessx3.run(sargan_model3.gen_img,
                                           feed_dict={
                                               sargan_model3.image:
                                               x_blur_list4[ibatch],
                                               sargan_model3.cond:
                                               x_blur_list4[ibatch]
                                           })
                x_blur_list4[ibatch] = blurred_batch

                # adding images to third autoencoder data set
                x_blur_list3.append(
                    blurred_batch.reshape(
                        [len(x_batch), img_size[0] * img_size[1]]))
                x_adv_list3.append(
                    processed_batch.reshape(
                        [len(x_batch), img_size[0] * img_size[1]]))
    # Final autoencoder setup: after this pass, list4 holds FLATTENED images
    # (classifier input format), unlike the earlier stages.
    with gx4.as_default():
        with tf.Session() as sessx4:
            sargan_model4 = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver4 = tf.train.Saver()
            sargan_saver4 = tf.train.import_meta_graph(trained_model_path4 +
                                                       '/sargan_mnist.meta')
            sargan_saver4.restore(
                sessx4, tf.train.latest_checkpoint(trained_model_path4))
            for ibatch in range(num_batches):
                processed_batch = sessx4.run(sargan_model4.gen_img,
                                             feed_dict={
                                                 sargan_model4.image:
                                                 x_adv_list4[ibatch],
                                                 sargan_model4.cond:
                                                 x_adv_list4[ibatch]
                                             })
                x_adv_list4[ibatch] = processed_batch.reshape(
                    [len(x_batch), img_size[0] * img_size[1]])

                blurred_batch = sessx4.run(sargan_model4.gen_img,
                                           feed_dict={
                                               sargan_model4.image:
                                               x_blur_list4[ibatch],
                                               sargan_model4.cond:
                                               x_blur_list4[ibatch]
                                           })
                x_blur_list4[ibatch] = blurred_batch.reshape(
                    [len(x_batch), img_size[0] * img_size[1]])
            # Final PSNR of the full 4-stage chain vs the clean images.
            psnr = 0
            for jj in range(num_batches):
                next_psnr = 0
                psnr_value = (tf.image.psnr(np.array(x_batch_list[jj]).reshape(
                    [64, img_size[0], img_size[1], img_size[2]]),
                                            np.array(x_adv_list4[jj]).reshape([
                                                64, img_size[0], img_size[1],
                                                img_size[2]
                                            ]),
                                            max_val=1))
                psnr_value = sessx4.run(psnr_value)
                for i in range(64):
                    next_psnr += psnr_value[i]
                psnr += next_psnr / 64
            psnr /= num_batches
            print("\n\nPSNR= ", psnr, "\n\n")

    # Classifier evaluation: score every image variant on a fresh graph.
    with g3.as_default():
        model3 = Model()
        saver2 = tf.train.Saver()
        with tf.Session() as sess3:
            saver2.restore(sess3, filename)
            for ibatch in range(num_batches):
                cur_xent_adv = np.zeros([4]).astype(dtype='float32')
                cur_xent_blur = np.zeros([4]).astype(dtype='float32')
                cur_corr_adv = np.zeros([4]).astype(dtype='float32')
                cur_corr_blur = np.zeros([4]).astype(dtype='float32')

                dict_nat = {
                    model3.x_input: x_batch_list[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                dict_corr = {
                    model3.x_input: x_corr_list[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                # First autoencoder dictionary
                dict_adv1 = {
                    model3.x_input: x_adv_list1[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                dict_blur1 = {
                    model3.x_input: x_blur_list1[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                # Second autoencoder dictionary
                dict_adv2 = {
                    model3.x_input: x_adv_list2[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                dict_blur2 = {
                    model3.x_input: x_blur_list2[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                # Third autoencoder dictionary
                dict_adv3 = {
                    model3.x_input: x_adv_list3[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                dict_blur3 = {
                    model3.x_input: x_blur_list3[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                # Fourth autoencoder dictionary
                dict_adv4 = {
                    model3.x_input: x_adv_list4[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                dict_blur4 = {
                    model3.x_input: x_blur_list4[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                # Regular Images
                cur_corr_nat, cur_xent_nat = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_nat)

                cur_corr_corr, cur_xent_corr = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_corr)

                # First autoencoder dictionary
                cur_corr_blur[0], cur_xent_blur[0] = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_blur1)

                cur_corr_adv[0], cur_xent_adv[0] = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_adv1)

                # Second autoencoder dictionary
                cur_corr_blur[1], cur_xent_blur[1] = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_blur2)

                cur_corr_adv[1], cur_xent_adv[1] = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_adv2)

                # Third autoencoder dictionary
                cur_corr_blur[2], cur_xent_blur[2] = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_blur3)

                cur_corr_adv[2], cur_xent_adv[2] = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_adv3)

                # Fourth autoencoder dictionary
                cur_corr_blur[3], cur_xent_blur[3] = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_blur4)

                cur_corr_adv[3], cur_xent_adv[3] = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_adv4)

                # Natural
                total_corr_nat += cur_corr_nat
                total_corr_corr += cur_corr_corr
                total_xent_nat += cur_xent_nat
                total_xent_corr += cur_xent_corr

                # running accuracy
                total_corr_adv += cur_corr_adv
                total_corr_blur += cur_corr_blur
                total_xent_adv += cur_xent_adv
                total_xent_blur += cur_xent_blur

            # Regular images
            # NOTE(review): averages divide by num_eval_examples although the
            # `- 1` above means fewer images were actually processed.
            avg_xent_nat = total_xent_nat / num_eval_examples
            avg_xent_corr = total_xent_corr / num_eval_examples
            acc_nat = total_corr_nat / num_eval_examples
            acc_corr = total_corr_corr / num_eval_examples

            # Total accuracy
            acc_adv = total_corr_adv / num_eval_examples
            acc_blur = total_corr_blur / num_eval_examples
            avg_xent_adv = total_xent_adv / num_eval_examples
            avg_xent_blur = total_xent_blur / num_eval_examples

    #sys.stdout = sys.__stdout__
    print("No Autoencoder")
    print('natural: {:.2f}%'.format(100 * acc_nat))
    print('Corrupted: {:.2f}%'.format(100 * acc_corr))
    print('avg nat loss: {:.4f}'.format(avg_xent_nat))
    print('avg corr loss: {:.4f} \n'.format(avg_xent_corr))

    print("First Autoencoder")
    print('natural with blur: {:.2f}%'.format(100 * acc_blur[0]))
    print('adversarial: {:.2f}%'.format(100 * acc_adv[0]))
    print('avg nat with blur loss: {:.4f}'.format(avg_xent_blur[0]))
    print('avg adv loss: {:.4f} \n'.format(avg_xent_adv[0]))

    print("Second Autoencoder")
    print('natural with blur: {:.2f}%'.format(100 * acc_blur[1]))
    print('adversarial: {:.2f}%'.format(100 * acc_adv[1]))
    print('avg nat with blur loss: {:.4f}'.format(avg_xent_blur[1]))
    print('avg adv loss: {:.4f} \n'.format(avg_xent_adv[1]))

    print("Third Autoencoder")
    print('natural with blur: {:.2f}%'.format(100 * acc_blur[2]))
    print('adversarial: {:.2f}%'.format(100 * acc_adv[2]))
    print('avg nat with blur loss: {:.4f}'.format(avg_xent_blur[2]))
    print('avg adv loss: {:.4f} \n'.format(avg_xent_adv[2]))

    print("Fourth Autoencoder")
    print('natural with blur: {:.2f}%'.format(100 * acc_blur[3]))
    print('adversarial: {:.2f}%'.format(100 * acc_adv[3]))
    print('avg nat with blur loss: {:.4f}'.format(avg_xent_blur[3]))
    print('avg adv loss: {:.4f} \n'.format(avg_xent_adv[3]))
def get_processed_data(x_batch, img_size, GPU_ID):
    """Run a batch of RGB images through the four autoencoder stages,
    one color channel at a time.

    The batch is lightly noised, split into 64-image sub-batches per
    channel (the trailing partial sub-batch is zero-padded to 64), pushed
    through ``get_channel1/2/3``, and the per-stage outputs are stitched
    back into full-size RGB arrays.

    Returns a 4-tuple of arrays shaped [len(x_batch), H, W, 3] — one per
    autoencoder stage.
    """
    NOISE_STD_RANGE = [0.0, 0.02]
    # Noise the batch, then append a unit axis so each channel slice is
    # [N, H, W, 1] — the shape the autoencoders expect.
    x_batch = np.expand_dims(
        add_gaussian_noise(x_batch, sd=NOISE_STD_RANGE[1]), 4)
    total = len(x_batch)
    full_batches = math.floor(total / 64)
    remainder = total - full_batches * 64
    batch_count = full_batches

    # channels[c] is the list of 64-image sub-batches for color channel c.
    channels = [[], [], []]
    for b in range(full_batches):
        lo = 64 * b
        for c in range(3):
            channels[c].append(np.copy(x_batch[lo:lo + 64, :, :, c]))

    if (remainder != 0):
        # Zero-pad the trailing partial sub-batch up to 64 images.
        scratch = np.zeros([64, img_size[0], img_size[1], img_size[2]])
        lo = 64 * full_batches
        for c in range(3):
            scratch[0:remainder, :, :, :] = x_batch[lo:lo + remainder, :, :, c]
            channels[c].append(np.copy(scratch))
        batch_count += 1

    # per_channel[c][stage] is the list of processed sub-batches for color
    # channel c after autoencoder stage `stage` (4 stages per channel).
    per_channel = [
        get_channel1(channels[0], img_size, batch_count, GPU_ID),
        get_channel2(channels[1], img_size, batch_count, GPU_ID),
        get_channel3(channels[2], img_size, batch_count, GPU_ID),
    ]

    # Reassemble each stage's sub-batches into one [N, H, W, 3] array.
    outputs = [
        np.zeros([total, img_size[0], img_size[1], 3]) for _ in range(4)
    ]
    for stage in range(4):
        for c in range(3):
            stage_batches = per_channel[c][stage]
            for b in range(full_batches):
                lo = 64 * b
                outputs[stage][lo:lo + 64, :, :, c] = np.copy(
                    stage_batches[b][:, :, :, 0])
            if (remainder != 0):
                # Only the first `remainder` rows of the padded sub-batch
                # are real images.
                lo = 64 * full_batches
                outputs[stage][lo:lo + remainder, :, :, c] = np.copy(
                    stage_batches[full_batches][0:remainder, :, :, 0])
    return outputs[0], outputs[1], outputs[2], outputs[3]
def main(args):
    """Train the SARGAN denoising autoencoder.

    Each epoch: run NUM_ITERATION discriminator/generator update steps on
    freshly corrupted training batches, checkpoint every SAVE_EVERY_EPOCH
    epochs, then measure PSNR on one validation batch and save three
    sample (original / corrupted / recovered) figures.

    Relies on module-level configuration: img_size, BATCH_SIZE,
    LEARNING_RATE, GPU_ID, MAX_EPOCH, NUM_ITERATION, NOISE_STD_RANGE,
    SAVE_EVERY_EPOCH, trained_model_path, output_path.
    """
    image_number = 0
    model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
    with tf.variable_scope("d_opt", reuse=tf.AUTO_REUSE):
        d_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.d_loss, var_list=model.d_vars)
    with tf.variable_scope("g_opt", reuse=tf.AUTO_REUSE):
        g_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver(max_to_keep=20)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                visible_device_list=str(GPU_ID))
    config = tf.ConfigProto(gpu_options=gpu_options)

    def _corrupt_batch(images):
        """Corruption model: darken + noise, gaussian blur, lighter noise."""
        corrupted = np.array([
            add_gaussian_noise(image / 1.5,
                               sd=np.random.uniform(NOISE_STD_RANGE[0],
                                                    NOISE_STD_RANGE[1]))
            for image in images
        ])
        # Loop variable deliberately not `i`: the original inline version
        # shadowed the enclosing iteration counter.
        for k in range(len(corrupted)):
            corrupted[k] = gaussian_filter(corrupted[k], sigma=1)
        return np.array([
            add_gaussian_noise(image,
                               sd=np.random.uniform(NOISE_STD_RANGE[0],
                                                    NOISE_STD_RANGE[1] / 2))
            for image in corrupted
        ])

    def _save_triptych(triple, path):
        """Save a 1x3 (Original / Corrupted / Recovered) figure to *path*."""
        fig = plt.figure(figsize=(14, 4))
        for idx, title in enumerate(("Original", "Corrupted", "Recovered")):
            ax = fig.add_subplot(1, 3, idx + 1)
            ax.imshow(triple[idx])
            ax.set_title(title, color='grey')
        plt.tight_layout()
        plt.savefig(path, dpi=300)

    progress_bar = tqdm(range(MAX_EPOCH), unit="epoch")
    # Loss history: one entry per training iteration.
    train_d_loss_values = []
    train_g_loss_values = []

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in progress_bar:
            train_loader, val_loader = get_data_loader(BATCH_SIZE, BATCH_SIZE)
            counter = 0
            epoch_start_time = time.time()
            trainiter = iter(train_loader)
            for _ in range(NUM_ITERATION):
                features, labels = next(trainiter)
                # NCHW -> NHWC for the TF model.
                features = features.data.numpy().transpose(0, 2, 3, 1)
                corrupted_batch = _corrupt_batch(features)
                # Discriminator step, then generator step, on the same pair.
                _, m = sess.run([d_opt, model.d_loss],
                                feed_dict={
                                    model.image: features,
                                    model.cond: corrupted_batch
                                })
                _, M = sess.run([g_opt, model.g_loss],
                                feed_dict={
                                    model.image: features,
                                    model.cond: corrupted_batch
                                })
                train_d_loss_values.append(m)
                train_g_loss_values.append(M)
                # print some notifications
                counter += 1
                if counter % 25 == 0:
                    print("\rEpoch [%d], Iteration [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, counter, time.time() - epoch_start_time, m, M))

            # save the trained network
            if epoch % SAVE_EVERY_EPOCH == 0:
                save_path = saver.save(sess,
                                       (trained_model_path + "/sargan_mnist"))
                print("\n\nModel saved in file: %s\n\n" % save_path)

            ##### TESTING FOR CURRENT EPOCH
            testiter = iter(val_loader)
            NUM_TEST_PER_EPOCH = 1

            sum_psnr = 0
            list_images = []
            for _ in range(NUM_TEST_PER_EPOCH):
                features, labels = next(testiter)
                features = features.data.numpy().transpose(0, 2, 3, 1)
                corrupted_batch = _corrupt_batch(features)

                gen_imgs = sess.run(model.gen_img,
                                    feed_dict={
                                        model.image: features,
                                        model.cond: corrupted_batch
                                    })
                # Keep three fixed samples per test batch for the figures.
                for sample_idx in (0, 17, 34):
                    list_images.append((features[sample_idx],
                                        corrupted_batch[sample_idx],
                                        gen_imgs[sample_idx]))
                for k in range(len(gen_imgs)):
                    sum_psnr += ski_me.compare_psnr(features[k], gen_imgs[k],
                                                    2)
            # NOTE(review): divisor 50 is hard-coded (the intended "50 test
            # images"); it is not necessarily len(gen_imgs) — confirm.
            average_psnr = sum_psnr / 50

            epoch_running_time = time.time() - epoch_start_time
            print("Epoch [%d] finished in %.2f s, average test PSNR: %s" %
                  (epoch, epoch_running_time, average_psnr))

            # De-normalization for display; identity with the current
            # mean/std, kept for datasets that need real statistics.
            display_mean = np.array([0, 0, 0])
            display_std = np.array([1, 1, 1])
            if epoch % SAVE_EVERY_EPOCH == 0:
                for n, triple in enumerate(list_images[:3], start=1):
                    imgs = display_std * triple + display_mean
                    _save_triptych(
                        imgs,
                        os.path.join(output_path,
                                     'image_%d_%d.jpg' % (image_number, n)))
                image_number += 1
            plt.close("all")
def _run_sargan_pass(batches, model_path, flatten_output=False):
    """Denoise every batch in *batches* with the SARGAN checkpoint at *model_path*.

    Builds a fresh graph/session, restores the checkpoint, and overwrites each
    entry of *batches* in place with the generator's output (feeding the batch
    as both ``image`` and ``cond``, matching the rest of this file).

    Parameters
    ----------
    batches : list of np.ndarray
        Image batches shaped ``[B, H, W, 1]``; mutated in place.
    model_path : str
        Directory containing ``sargan_mnist.meta`` and its checkpoint.
    flatten_output : bool
        When True, reshape each output batch to ``[B, H*W]`` (used for the
        final pass so the classifier can consume flat vectors).

    Returns
    -------
    list of np.ndarray
        The same *batches* list, for chaining convenience.
    """
    graph = tf.Graph()
    with graph.as_default():
        with tf.Session() as sess:
            model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            # Kept for parity with the original restore sequence: the Saver is
            # built, then immediately replaced by the imported meta-graph saver.
            saver_ = tf.train.Saver()
            saver_ = tf.train.import_meta_graph(model_path + '/sargan_mnist.meta')
            saver_.restore(sess, tf.train.latest_checkpoint(model_path))
            for ibatch in range(len(batches)):
                out = sess.run(model.gen_img,
                               feed_dict={model.image: batches[ibatch],
                                          model.cond: batches[ibatch]})
                if flatten_output:
                    # Flatten [B, H, W, 1] -> [B, H*W] for the classifier.
                    out = out.reshape([len(out), img_size[0] * img_size[1]])
                batches[ibatch] = out
    return batches


def evaluate_checkpoint(filename):
    """Evaluate classifier checkpoint *filename* on adversarial MNIST batches.

    Pipeline per batch:
      1. Draw a batch, flatten it, and perturb it with the per-batch PGD-style
         attack in the module-level ``attacks`` list (each batch uses its own
         attack, presumably with an increasing perturbation budget).
      2. Add Gaussian noise and run the result through a cascade of four
         SARGAN denoisers (checkpoints ``trained_model_path`` .. ``path4``).
      3. Score both the raw adversarial batch and the denoised batch with the
         classifier restored from *filename*.

    Parameters
    ----------
    filename : str
        Path to the classifier checkpoint for ``tf.train.Saver.restore``.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``loop_list_adv`` and ``loop_list_auto``, each shaped
        ``[2, number_of_runs]``: row 0 holds the perturbation magnitude
        (``starting_pert`` stepped by ``change`` per batch), row 1 the batch
        accuracy (adversarial-only vs. denoised, respectively).
        NOTE(review): rows are indexed by batch up to ``num_batches`` — this
        assumes ``number_of_runs == num_batches``; confirm at the call site.
    """
    loop_list_adv = np.zeros([2, number_of_runs])
    loop_list_auto = np.zeros([2, number_of_runs])
    epsilonc = starting_pert

    with tf.Session() as sess:
        # Restore the attacked classifier in the default graph, where the
        # module-level ``saver``/``model`` and the ``attacks`` live.
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))

        num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
        x_corr_list = []   # adversarial batches, flat [B, H*W]
        x_adv_list = []    # noisy adversarial batches, [B, H, W, C]
        y_batch_list = []

        train_loader = get_data(BATCH_SIZE)
        trainiter = iter(cycle(train_loader))
        for ibatch in range(num_batches):
            x_batch2, y_batch = next(trainiter)
            y_batch_list.append(y_batch)
            # Torch NCHW tensor -> NumPy NHWC array.
            x_batch2 = np.array(x_batch2.data.numpy().transpose(0, 2, 3, 1))
            x_batch = np.zeros([len(x_batch2), img_size[0] * img_size[1]])
            for i in range(len(x_batch2)):
                x_batch[i] = x_batch2[i].reshape([img_size[0] * img_size[1]])

            # One attack object per batch (e.g. an epsilon schedule).
            x_batch_adv = attacks[ibatch].perturb(x_batch, y_batch, sess)

            x_batch_adv2 = np.zeros(
                [len(x_batch), img_size[0], img_size[1], img_size[2]])
            for k in range(len(x_batch)):
                # NOTE(review): the original called
                # uniform(NOISE_STD_RANGE[1], NOISE_STD_RANGE[1]), which always
                # yields NOISE_STD_RANGE[1]; written out directly here.  The
                # training code samples uniform(NOISE_STD_RANGE[0],
                # NOISE_STD_RANGE[1]) — confirm a fixed max-noise level is
                # intended for evaluation.
                x_batch_adv2[k] = add_gaussian_noise(
                    x_batch_adv[k].reshape(
                        [img_size[0], img_size[1], img_size[2]]),
                    sd=NOISE_STD_RANGE[1])
            x_corr_list.append(x_batch_adv)
            x_adv_list.append(x_batch_adv2)

    # Four-stage SARGAN denoising cascade; the last pass flattens the output
    # so the classifier below can consume it.
    _run_sargan_pass(x_adv_list, trained_model_path)
    _run_sargan_pass(x_adv_list, trained_model_path2)
    _run_sargan_pass(x_adv_list, trained_model_path3)
    _run_sargan_pass(x_adv_list, trained_model_path4, flatten_output=True)

    g3 = tf.Graph()
    with g3.as_default():
        model3 = Model()
        saver2 = tf.train.Saver()
        with tf.Session() as sess3:
            saver2.restore(sess3, filename)
            for ibatch in range(num_batches):
                dict_corr = {
                    model3.x_input: x_corr_list[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }
                dict_adv = {
                    model3.x_input: x_adv_list[ibatch],
                    model3.y_input: y_batch_list[ibatch]
                }

                cur_corr_corr, cur_xent_corr = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_corr)
                cur_corr_adv, cur_xent_adv = sess3.run(
                    [model3.num_correct, model3.xent], feed_dict=dict_adv)

                # Record (perturbation size, batch accuracy) for both variants.
                loop_list_adv[0, ibatch] = epsilonc
                loop_list_adv[1, ibatch] = cur_corr_adv / eval_batch_size
                loop_list_auto[0, ibatch] = epsilonc
                loop_list_auto[1, ibatch] = cur_corr_corr / eval_batch_size
                epsilonc += change
    return loop_list_adv, loop_list_auto