def _denoise_batches_with_sargan(model_path, batches, num_batches, config):
    """Replace ``batches[0:num_batches]`` with one SARGAN's reconstructions.

    The generator is rebuilt in a fresh ``tf.Graph`` so each denoiser's
    variables live in their own graph, and its weights are restored from the
    latest checkpoint found in ``model_path``.

    Args:
        model_path: directory containing the SARGAN checkpoint.
        batches: list of image batches; the first ``num_batches`` entries
            are overwritten in place with the generator output.
        num_batches: number of leading entries of ``batches`` to process.
        config: ``tf.ConfigProto`` used to open the session.
    """
    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            # Restore straight into the variables of the model built above.
            # The previous version additionally called
            # tf.train.import_meta_graph, which added a second copy of the
            # whole saved graph that was never evaluated (only
            # ``model.gen_img`` is fetched).
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(model_path))
            for ibatch in range(num_batches):
                noisy = batches[ibatch]
                # The noisy batch is fed as both the image and the
                # conditioning input; only the generator output is kept.
                batches[ibatch] = sess.run(
                    model.gen_img,
                    feed_dict={model.image: noisy, model.cond: noisy})


def transfrom_data(NUM_TEST_PER_EPOCH):
    """Build clean/noisy batches and denoise the leading ones with a SARGAN chain.

    Pulls ``NUM_ITERATION`` batches from ``get_data(BATCH_SIZE)``. For each
    batch it:
    - runs ``tf.keras.applications.resnet.preprocess_input``;
    - copies channel 2 of the result into channel 0 of an ``img_size``-shaped
      array;
    - shifts by 123.68, clips to [0, 255] and rescales to [0, 1];
    - adds Gaussian noise per image, with ``sd`` drawn uniformly from
      ``NOISE_STD_RANGE``.

    The first ``NUM_TEST_PER_EPOCH`` noisy batches are then passed through
    the SARGAN checkpoints at ``trained_model_path1``, ``trained_model_path2``
    and ``trained_model_path3``, in that order.

    Args:
        NUM_TEST_PER_EPOCH: how many leading batches go through the SARGAN
            chain; the remaining batches are returned noisy but otherwise
            unprocessed.

    Returns:
        ``(encoded_data, original_data)``: two lists with one entry per
        iteration. ``original_data[i]`` is the clean batch and
        ``encoded_data[i]`` its noisy (and, for ``i < NUM_TEST_PER_EPOCH``,
        SARGAN-processed) counterpart.
    """
    encoded_data = []
    original_data = []
    train_loader = get_data(BATCH_SIZE)
    trainiter = iter(train_loader)
    for _ in range(NUM_ITERATION):
        features2, _labels = next(trainiter)
        # NCHW tensor -> NHWC array scaled to [0, 255] before preprocessing.
        features2 = np.array(
            tf.keras.applications.resnet.preprocess_input(
                features2.data.numpy().transpose(0, 2, 3, 1) * 255))
        features = np.zeros(
            [len(features2), img_size[0], img_size[1], img_size[2]])
        features[:, :, :, 0] = features2[:, :, :, 2]
        # Presumably undoes the mean (123.68) that preprocess_input
        # subtracted from this channel; then clip and rescale to [0, 1].
        features = np.clip(features + 123.68, 0, 255) / 255
        # Keep the clean copy before noise is written into ``features``.
        original_images = np.copy(features)
        for k in range(len(features)):
            features[k] = add_gaussian_noise(
                features[k],
                sd=np.random.uniform(NOISE_STD_RANGE[0], NOISE_STD_RANGE[1]))
        # Both arrays are freshly allocated each iteration, so no copy is
        # needed before storing them.
        encoded_data.append(features)
        original_data.append(original_images)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                visible_device_list=str(GPU_ID))
    config = tf.ConfigProto(gpu_options=gpu_options)
    # Apply the denoisers in sequence; each consumes the previous output.
    for model_path in (trained_model_path1, trained_model_path2,
                       trained_model_path3):
        _denoise_batches_with_sargan(model_path, encoded_data,
                                     NUM_TEST_PER_EPOCH, config)

    return encoded_data, original_data
def _save_comparison_figure(images, path):
    """Save an Original / Corrupted / Recovered figure to ``path``.

    Args:
        images: sequence of three images (original, corrupted, recovered),
            already mapped to display range by the caller.
        path: output file passed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(14, 4))
    titles = ("Original", "Corrupted", "Recovered")
    for position, (image, title) in enumerate(zip(images, titles), start=1):
        ax = fig.add_subplot(1, 3, position)
        ax.imshow(image)
        ax.set_title(title, color='grey')
    plt.tight_layout()
    plt.savefig(path, dpi=300)


def main(args):
    """Train SARGAN and evaluate it on validation batches after every epoch.

    Each epoch:
    - rebuilds the data with ``transfrom_data``;
    - runs one discriminator step and one generator step per batch;
    - scores the generator on the validation batches with PSNR.

    Every ``SAVE_EVERY_EPOCH`` epochs it also writes a checkpoint to
    ``trained_model_path/sargan_mnist`` and up to three comparison figures
    to ``output_path``.

    Args:
        args: unused; kept for the entry-point signature.
    """
    image_number = 0
    model = SARGAN(img_size, BATCH_SIZE, img_channel=img_size[2])
    with tf.variable_scope("d_opt", reuse=tf.AUTO_REUSE):
        d_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.d_loss, var_list=model.d_vars)
    with tf.variable_scope("g_opt", reuse=tf.AUTO_REUSE):
        g_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.g_loss, var_list=model.g_vars)
    # One Saver over this graph's own variables serves both fresh training
    # and retraining. The previous retrain path used import_meta_graph, which
    # adds a second copy of the saved graph; the Saver it returns is built
    # from that copy's saver_def, so it is presumably not wired to
    # ``model``'s variables.
    saver = tf.train.Saver(max_to_keep=20)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                visible_device_list=str(GPU_ID))
    config = tf.ConfigProto(gpu_options=gpu_options)

    progress_bar = tqdm(range(MAX_EPOCH), unit="epoch")
    # Loss history, one entry per training iteration.
    train_d_loss_values = []
    train_g_loss_values = []

    # Used to map images back to display range: image = std * image + mean.
    display_mean = np.array([0.485, 0.456, 0.406])
    display_std = np.array([0.229, 0.224, 0.225])

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if retrain != 0:
            saver.restore(sess, tf.train.latest_checkpoint(trained_model_path))
        for epoch in progress_bar:
            NUM_TEST_PER_EPOCH = 1
            epoch_start_time = time.time()
            # NOTE(review): this unpacks four values, but the transfrom_data
            # defined in this file returns only (encoded_data, original_data).
            # Confirm which variant this entry point is meant to call.
            encoded_data, original_data, val_data, val_original = \
                transfrom_data(NUM_TEST_PER_EPOCH)

            for i in range(NUM_ITERATION):
                features = original_data[i]
                corrupted_batch = encoded_data[i]
                feed = {model.image: features, model.cond: corrupted_batch}
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, M = sess.run([g_opt, model.g_loss], feed_dict=feed)
                train_d_loss_values.append(m)
                train_g_loss_values.append(M)
                counter = i + 1
                if counter % 25 == 0:
                    print("\rEpoch [%d], Iteration [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                          % (epoch, counter, time.time() - epoch_start_time, m, M))

            if epoch % SAVE_EVERY_EPOCH == 0:
                save_path = saver.save(sess, trained_model_path + "/sargan_mnist")
                print("\n\nModel saved in file: %s\n\n" % save_path)

            # ---- Evaluation for the current epoch ----
            sum_psnr = 0
            num_psnr = 0
            list_images = []
            for j in range(NUM_TEST_PER_EPOCH):
                features = val_original[j]
                corrupted_batch = val_data[j]
                gen_imgs = sess.run(model.gen_img, feed_dict={
                    model.image: features, model.cond: corrupted_batch})
                print(features.shape, gen_imgs.shape)
                # Keep sample triples at fixed positions, skipping any index
                # the batch is too small to contain.
                for idx in (0, 17, 34):
                    if idx < len(gen_imgs):
                        list_images.append(
                            (features[idx], corrupted_batch[idx], gen_imgs[idx]))
                for current_img, recovered_img in zip(features, gen_imgs):
                    sum_psnr += ski_me.compare_psnr(current_img, recovered_img, 1)
                    num_psnr += 1
            # Average over the images actually scored (was a hard-coded 50).
            average_psnr = sum_psnr / num_psnr if num_psnr else float('nan')

            epoch_running_time = time.time() - epoch_start_time
            print("Epoch [%d]: average PSNR: %.4f, epoch time: %4.4f"
                  % (epoch, average_psnr, epoch_running_time))

            if epoch % SAVE_EVERY_EPOCH == 0:
                for sample_index, triple in enumerate(list_images[:3], start=1):
                    imgs = display_std * np.asarray(triple) + display_mean
                    _save_comparison_figure(
                        imgs,
                        os.path.join(output_path, 'image_%d_%d.jpg'
                                     % (image_number, sample_index)))
                image_number += 1
            plt.close("all")
def main(args):
    """Train the SARGAN denoiser.

    Each epoch rebuilds the data with ``transfrom_data`` and then runs one
    discriminator step followed by one generator step on every batch. Every
    ``SAVE_EVERY_EPOCH`` epochs the session is checkpointed to
    ``trained_model_path/sargan_mnist``. When ``retrain`` is non-zero,
    training resumes from the latest checkpoint in ``trained_model_path``.

    Args:
        args: unused; kept for the entry-point signature.
    """
    model = SARGAN(img_size, BATCH_SIZE, img_channel=img_size[2])
    with tf.variable_scope("d_opt", reuse=tf.AUTO_REUSE):
        d_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.d_loss, var_list=model.d_vars)
    with tf.variable_scope("g_opt", reuse=tf.AUTO_REUSE):
        g_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.g_loss, var_list=model.g_vars)
    # One Saver over this graph's own variables serves both fresh training
    # and retraining. The previous retrain path used import_meta_graph on a
    # graph that already contains the model. That imports a second copy of
    # the saved graph, and the returned Saver is built from that copy's
    # saver_def, so restore/save presumably did not touch ``model``'s
    # variables.
    saver = tf.train.Saver(max_to_keep=20)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                visible_device_list=str(GPU_ID))
    config = tf.ConfigProto(gpu_options=gpu_options)

    progress_bar = tqdm(range(MAX_EPOCH), unit="epoch")
    # Loss history, one entry per training iteration.
    train_d_loss_values = []
    train_g_loss_values = []

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if retrain != 0:
            # Overwrite the fresh initialisation with the latest checkpoint.
            saver.restore(sess, tf.train.latest_checkpoint(trained_model_path))
        for epoch in progress_bar:
            NUM_TEST_PER_EPOCH = 1
            epoch_start_time = time.time()
            encoded_data, original_data = transfrom_data(NUM_TEST_PER_EPOCH)
            for i in range(NUM_ITERATION):
                features = original_data[i]
                corrupted_batch = encoded_data[i]
                feed = {model.image: features, model.cond: corrupted_batch}
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, M = sess.run([g_opt, model.g_loss], feed_dict=feed)
                train_d_loss_values.append(m)
                train_g_loss_values.append(M)
                counter = i + 1
                if counter % 25 == 0:
                    print("\rEpoch [%d], Iteration [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                          % (epoch, counter, time.time() - epoch_start_time, m, M))

            if epoch % SAVE_EVERY_EPOCH == 0:
                save_path = saver.save(sess, trained_model_path + "/sargan_mnist")
                print("\n\nModel saved in file: %s\n\n" % save_path)
def _eval_denoise_with_sargan(model_path, batches):
    """Replace every batch in ``batches`` with one SARGAN's reconstruction.

    The generator is rebuilt in a fresh ``tf.Graph`` and its weights are
    restored from the latest checkpoint in ``model_path``. Each batch is fed
    as both the image and the conditioning input.

    Args:
        model_path: directory containing the SARGAN checkpoint.
        batches: list of NHWC image batches, overwritten in place.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            # Restore straight into the model built above; import_meta_graph
            # (used previously) only added an unused second copy of the graph.
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(model_path))
            for ibatch, batch in enumerate(batches):
                batches[ibatch] = sess.run(
                    model.gen_img,
                    feed_dict={model.image: batch, model.cond: batch})


def evaluate_checkpoint(filename):
    """Score a classifier on adversarial batches with and without SARGAN denoising.

    For each of ``num_batches`` batches drawn (cyclically) from
    ``get_data(BATCH_SIZE)``:
    - the flattened images are perturbed with ``attacks[ibatch]``, in a
      session restored from the latest checkpoint in ``model_dir``;
    - Gaussian noise is added;
    - the noisy batch is passed through the SARGAN checkpoints at
      ``trained_model_path``, ``trained_model_path2``, ``trained_model_path3``
      and ``trained_model_path4``, in order.

    The classifier restored from ``filename`` is then scored on both the raw
    perturbed batch and the denoised batch.

    Args:
        filename: checkpoint path of the classifier to evaluate.

    Returns:
        ``(loop_list_adv, loop_list_auto)``, each of shape
        ``(2, number_of_runs)``.
        - Row 0 of both holds ``epsilonc``, which starts at ``starting_pert``
          and grows by ``change`` per batch.
        - Row 1 of ``loop_list_adv`` is the accuracy on the denoised batch.
        - Row 1 of ``loop_list_auto`` is the accuracy on the raw perturbed
          batch.

        NOTE(review): assumes ``num_batches <= number_of_runs``.
    """
    loop_list_adv = np.zeros([2, number_of_runs])
    loop_list_auto = np.zeros([2, number_of_runs])
    epsilonc = starting_pert
    num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
    flat_size = img_size[0] * img_size[1]

    x_corr_list = []   # perturbed batches, flattened, not denoised
    x_adv_list = []    # perturbed + noisy batches, later denoised
    y_batch_list = []

    with tf.Session() as sess:
        # Restore the checkpoint the attacks are computed against.
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
        trainiter = iter(cycle(get_data(BATCH_SIZE)))
        for ibatch in range(num_batches):
            x_batch2, y_batch = next(trainiter)
            y_batch_list.append(y_batch)
            # NCHW tensor -> NHWC array -> one flattened float64 row per image.
            x_batch2 = np.array(x_batch2.data.numpy().transpose(0, 2, 3, 1))
            x_batch = x_batch2.reshape(len(x_batch2), flat_size).astype(np.float64)
            x_batch_adv = attacks[ibatch].perturb(x_batch, y_batch, sess)
            x_batch_adv2 = np.zeros(
                [len(x_batch), img_size[0], img_size[1], img_size[2]])
            for k in range(len(x_batch)):
                # NOTE(review): both bounds are NOISE_STD_RANGE[1], so the
                # noise std is always the upper end of the range.
                x_batch_adv2[k] = add_gaussian_noise(
                    x_batch_adv[k].reshape([img_size[0], img_size[1], img_size[2]]),
                    sd=np.random.uniform(NOISE_STD_RANGE[1], NOISE_STD_RANGE[1]))
            x_corr_list.append(x_batch_adv)
            x_adv_list.append(x_batch_adv2)

    # Chain the four denoisers; each consumes the previous one's output.
    for model_path in (trained_model_path, trained_model_path2,
                       trained_model_path3, trained_model_path4):
        _eval_denoise_with_sargan(model_path, x_adv_list)
    # Flatten for the classifier using each batch's own length (the old code
    # reused the length of the last attacked batch for every batch).
    x_adv_list = [batch.reshape(len(batch), flat_size) for batch in x_adv_list]

    with tf.Graph().as_default():
        model3 = Model()
        saver2 = tf.train.Saver()
        with tf.Session() as sess3:
            saver2.restore(sess3, filename)
            for ibatch in range(num_batches):
                y_batch = y_batch_list[ibatch]
                cur_corr_corr = sess3.run(
                    model3.num_correct,
                    feed_dict={model3.x_input: x_corr_list[ibatch],
                               model3.y_input: y_batch})
                cur_corr_adv = sess3.run(
                    model3.num_correct,
                    feed_dict={model3.x_input: x_adv_list[ibatch],
                               model3.y_input: y_batch})
                # Accuracy over the batch actually evaluated; batches come
                # from get_data(BATCH_SIZE), which need not equal
                # eval_batch_size.
                batch_size = len(y_batch)
                loop_list_adv[0, ibatch] = epsilonc
                loop_list_adv[1, ibatch] = cur_corr_adv / batch_size
                loop_list_auto[0, ibatch] = epsilonc
                loop_list_auto[1, ibatch] = cur_corr_corr / batch_size
                epsilonc += change
    return loop_list_adv, loop_list_auto
# Example #5
# 0
def evaluate_checkpoint(filename):
    #sys.stdout = open(os.devnull, 'w')
    #Different graphs for all the models
    gx1 =tf.Graph()
    gx2 =tf.Graph()
    gx3 =tf.Graph()
    gx4 =tf.Graph()
    g3=tf.Graph()
    with tf.Session() as sess:
    # Restore the checkpoint
        saver.restore(sess, filename);
    
        # Iterate over the samples batch-by-batch
        #number of batches
        num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
        total_xent_nat = 0.
        total_xent_corr = 0.
        total_corr_nat = 0.
        total_corr_corr = 0.
        
        total_corr_adv = np.zeros([4]).astype(dtype='float32')
        total_corr_blur = np.zeros([4]).astype(dtype='float32')
        total_xent_adv = np.zeros([4]).astype(dtype='float32')
        total_xent_blur = np.zeros([4]).astype(dtype='float32')
        
        #storing the various images
        x_batch_list=[]
        x_corr_list=[]
        x_blur_list1=[]
        x_adv_list1=[]
        x_blur_list2=[]
        x_adv_list2=[]
        x_blur_list3=[]
        x_adv_list3=[]
        x_blur_list4=[]
        x_adv_list4=[]
        #Storing y values
        y_batch_list=[]
        
        train_loader= get_data(BATCH_SIZE)
        trainiter = iter(cycle(train_loader))
        for ibatch in range(num_batches):
                
            x_batch2, y_batch = next(trainiter)
            y_batch_list.append(y_batch)
            x_batch2 = np.array(x_batch2.data.numpy().transpose(0,2,3,1))
            x_batch=np.zeros([len(x_batch2),img_size[0]*img_size[1]])
            for i in range(len(x_batch2)):
                x_batch[i]=x_batch2[i].reshape([img_size[0]*img_size[1]])
            x_batch_adv = attack.perturb(x_batch, y_batch, sess)
            
            x_batch2=np.zeros([len(x_batch),img_size[0],img_size[1],img_size[2]])
            x_batch_adv2=np.zeros([len(x_batch),img_size[0],img_size[1],img_size[2]])
            for k in range(len(x_batch)):
                x_batch2[k]=add_gaussian_noise(x_batch[k].reshape([img_size[0],img_size[1],img_size[2]]), sd=np.random.uniform(NOISE_STD_RANGE[1], NOISE_STD_RANGE[1]))
                x_batch_adv2[k]=add_gaussian_noise(x_batch_adv[k].reshape([img_size[0],img_size[1],img_size[2]]), sd=np.random.uniform(NOISE_STD_RANGE[1], NOISE_STD_RANGE[1]))
            
            x_batch_list.append(x_batch)
            x_corr_list.append(x_batch_adv)
            x_blur_list4.append(x_batch2)
            x_adv_list4.append(x_batch_adv2)
            
    #Running through first autoencoder
    with gx1.as_default():
        with tf.Session() as sess2:
            sargan_model=SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver= tf.train.Saver()    
            sargan_saver = tf.train.import_meta_graph(trained_model_path+'/sargan_mnist.meta');
            sargan_saver.restore(sess2,tf.train.latest_checkpoint(trained_model_path));
            for ibatch in range(num_batches):
                processed_batch=sess2.run(sargan_model.gen_img,feed_dict={sargan_model.image: x_adv_list4[ibatch], sargan_model.cond: x_adv_list4[ibatch]})
                x_adv_list4[ibatch]=processed_batch
                
                blurred_batch=sess2.run(sargan_model.gen_img,feed_dict={sargan_model.image: x_blur_list4[ibatch], sargan_model.cond: x_blur_list4[ibatch]})
                x_blur_list4[ibatch]=blurred_batch
                
                #adding images to first autoencoder data set
                x_blur_list1.append(blurred_batch.reshape([len(x_batch),img_size[0]*img_size[1]]))
                x_adv_list1.append(processed_batch.reshape([len(x_batch),img_size[0]*img_size[1]]))
    with gx2.as_default():
        with tf.Session() as sessx2:
            sargan_model2=SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver2= tf.train.Saver()    
            sargan_saver2 = tf.train.import_meta_graph(trained_model_path2+'/sargan_mnist.meta');
            sargan_saver2.restore(sessx2,tf.train.latest_checkpoint(trained_model_path2));
            for ibatch in range(num_batches):
                processed_batch=sessx2.run(sargan_model2.gen_img,feed_dict={sargan_model2.image: x_adv_list4[ibatch], sargan_model2.cond: x_adv_list4[ibatch]})
                x_adv_list4[ibatch]=(processed_batch)
                
                blurred_batch=sessx2.run(sargan_model2.gen_img,feed_dict={sargan_model2.image: x_blur_list4[ibatch], sargan_model2.cond: x_blur_list4[ibatch]})
                x_blur_list4[ibatch]=(blurred_batch)
                
                #adding images to second autoencoder data set
                x_blur_list2.append(blurred_batch.reshape([len(x_batch),img_size[0]*img_size[1]]))
                x_adv_list2.append(processed_batch.reshape([len(x_batch),img_size[0]*img_size[1]]))
    with gx3.as_default():
        with tf.Session() as sessx3:
            sargan_model3=SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver3= tf.train.Saver()    
            sargan_saver3 = tf.train.import_meta_graph(trained_model_path3+'/sargan_mnist.meta');
            sargan_saver3.restore(sessx3,tf.train.latest_checkpoint(trained_model_path3));
            for ibatch in range(num_batches):
                processed_batch=sessx3.run(sargan_model3.gen_img,feed_dict={sargan_model3.image: x_adv_list4[ibatch], sargan_model3.cond: x_adv_list4[ibatch]})
                x_adv_list4[ibatch]=processed_batch
                
                blurred_batch=sessx3.run(sargan_model3.gen_img,feed_dict={sargan_model3.image: x_blur_list4[ibatch], sargan_model3.cond: x_blur_list4[ibatch]})
                x_blur_list4[ibatch]=blurred_batch
                
                #adding images to third autoencoder data set
                x_blur_list3.append(blurred_batch.reshape([len(x_batch),img_size[0]*img_size[1]]))
                x_adv_list3.append(processed_batch.reshape([len(x_batch),img_size[0]*img_size[1]]))
    #Final autoencoder setup
    with gx4.as_default():
        with tf.Session() as sessx4:
            sargan_model4=SARGAN(img_size, BATCH_SIZE, img_channel=1)
            sargan_saver4= tf.train.Saver()    
            sargan_saver4 = tf.train.import_meta_graph(trained_model_path4+'/sargan_mnist.meta');
            sargan_saver4.restore(sessx4,tf.train.latest_checkpoint(trained_model_path4));
            for ibatch in range(num_batches):
                processed_batch=sessx4.run(sargan_model4.gen_img,feed_dict={sargan_model4.image: x_adv_list4[ibatch], sargan_model4.cond: x_adv_list4[ibatch]})
                x_adv_list4[ibatch]=processed_batch.reshape([len(x_batch),img_size[0]*img_size[1]])
                
                blurred_batch=sessx4.run(sargan_model4.gen_img,feed_dict={sargan_model4.image: x_blur_list4[ibatch], sargan_model4.cond: x_blur_list4[ibatch]})
                x_blur_list4[ibatch]=blurred_batch.reshape([len(x_batch),img_size[0]*img_size[1]])
                
    with g3.as_default():
        model3 = Model()
        saver2 = tf.train.Saver()
        with tf.Session() as sess3:
            saver2.restore(sess3, filename);
            for ibatch in range(num_batches):
                cur_xent_adv = np.zeros([4]).astype(dtype='float32')
                cur_xent_blur = np.zeros([4]).astype(dtype='float32')
                cur_corr_adv = np.zeros([4]).astype(dtype='float32')
                cur_corr_blur = np.zeros([4]).astype(dtype='float32')
                
                dict_nat = {model3.x_input: x_batch_list[ibatch],
                        model3.y_input: y_batch_list[ibatch]}
                
                dict_corr = {model3.x_input: x_corr_list[ibatch],
                        model3.y_input: y_batch_list[ibatch]}
                
                
                #First autoencoder dictionary
                dict_adv1 = {model3.x_input: x_adv_list1[ibatch],
                          model3.y_input: y_batch_list[ibatch]}
                
                dict_blur1 = {model3.x_input: x_blur_list1[ibatch],
                          model3.y_input: y_batch_list[ibatch]}
                
                
                #Second autoencoder dictionary
                dict_adv2 = {model3.x_input: x_adv_list2[ibatch],
                          model3.y_input: y_batch_list[ibatch]}
                
                dict_blur2 = {model3.x_input: x_blur_list2[ibatch],
                          model3.y_input: y_batch_list[ibatch]}
                
                
                #Third autoencoder dictionary
                dict_adv3 = {model3.x_input: x_adv_list3[ibatch],
                          model3.y_input: y_batch_list[ibatch]}
                
                dict_blur3 = {model3.x_input: x_blur_list3[ibatch],
                          model3.y_input: y_batch_list[ibatch]}
                
                
                #Fourth autoencoder dictionary
                dict_adv4 = {model3.x_input: x_adv_list4[ibatch],
                          model3.y_input: y_batch_list[ibatch]}
                
                dict_blur4 = {model3.x_input: x_blur_list4[ibatch],
                          model3.y_input: y_batch_list[ibatch]}
                
                
                #Regular Images
                cur_corr_nat, cur_xent_nat = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_nat)
                
                cur_corr_corr, cur_xent_corr = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_corr)
                
                #First autoencoder dictionary
                cur_corr_blur[0], cur_xent_blur[0] = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_blur1)
                
                cur_corr_adv[0], cur_xent_adv[0] = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_adv1)
                
                #Second autoencoder dictionary
                cur_corr_blur[1], cur_xent_blur[1] = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_blur2)
                
                cur_corr_adv[1], cur_xent_adv[1] = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_adv2)
                
                #Third autoencoder dictionary
                cur_corr_blur[2], cur_xent_blur[2] = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_blur3)
                
                cur_corr_adv[2], cur_xent_adv[2] = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_adv3)
                
                #Fourth autoencoder dictionary
                cur_corr_blur[3], cur_xent_blur[3] = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_blur4)
                
                cur_corr_adv[3], cur_xent_adv[3] = sess3.run(
                                              [model3.num_correct,model3.xent],
                                              feed_dict = dict_adv4)
                
                
                
                #Natural
                total_corr_nat += cur_corr_nat
                total_corr_corr += cur_corr_corr
                total_xent_nat += cur_xent_nat
                total_xent_corr += cur_xent_corr
                
                #running accuracy
                total_corr_adv += cur_corr_adv
                total_corr_blur += cur_corr_blur
                total_xent_adv += cur_xent_adv
                total_xent_blur += cur_xent_blur
            
            
    
            #Regual images
            avg_xent_nat = total_xent_nat / num_eval_examples
            avg_xent_corr = total_xent_corr / num_eval_examples
            acc_nat = total_corr_nat / num_eval_examples
            acc_corr = total_corr_corr / num_eval_examples
            
            #Total accuracy
            acc_adv = total_corr_adv / num_eval_examples
            acc_blur = total_corr_blur / num_eval_examples
            avg_xent_adv = total_xent_adv / num_eval_examples
            avg_xent_blur = total_xent_blur / num_eval_examples
            
    #sys.stdout = sys.__stdout__
    print("No Autoencoder")
    print('natural: {:.2f}%'.format(100 * acc_nat))
    print('Corrupted: {:.2f}%'.format(100 * acc_corr))
    print('avg nat loss: {:.4f}'.format(avg_xent_nat))
    print('avg corr loss: {:.4f} \n'.format(avg_xent_corr))
    
    print("First Autoencoder")
    print('natural with blur: {:.2f}%'.format(100 * acc_blur[0]))
    print('adversarial: {:.2f}%'.format(100 * acc_adv[0]))
    print('avg nat with blur loss: {:.4f}'.format(avg_xent_blur[0]))
    print('avg adv loss: {:.4f} \n'.format(avg_xent_adv[0]))
    
    print("Second Autoencoder")
    print('natural with blur: {:.2f}%'.format(100 * acc_blur[1]))
    print('adversarial: {:.2f}%'.format(100 * acc_adv[1]))
    print('avg nat with blur loss: {:.4f}'.format(avg_xent_blur[1]))
    print('avg adv loss: {:.4f} \n'.format(avg_xent_adv[1]))
    
    print("Third Autoencoder")
    print('natural with blur: {:.2f}%'.format(100 * acc_blur[2]))
    print('adversarial: {:.2f}%'.format(100 * acc_adv[2]))
    print('avg nat with blur loss: {:.4f}'.format(avg_xent_blur[2]))
    print('avg adv loss: {:.4f} \n'.format(avg_xent_adv[2]))
    
    print("Fourth Autoencoder")
    print('natural with blur: {:.2f}%'.format(100 * acc_blur[3]))
    print('adversarial: {:.2f}%'.format(100 * acc_adv[3]))
    print('avg nat with blur loss: {:.4f}'.format(avg_xent_blur[3]))
    print('avg adv loss: {:.4f} \n'.format(avg_xent_adv[3]))
def _corrupt_and_denoise(sess, sargan_model, data_iter, num_batches):
    """Noise ``num_batches`` batches from ``data_iter`` and denoise them with SARGAN.

    Args:
        sess: session in which ``sargan_model``'s graph lives.
        sargan_model: SARGAN instance whose ``gen_img`` output is evaluated.
        data_iter: iterator yielding ``(features, labels)`` pairs (presumably
            torch tensors, given ``.data.numpy()``); labels are ignored.
        num_batches: number of batches to consume from ``data_iter``.

    Returns:
        (denoised, originals): two lists with one entry per batch. ``denoised``
        holds the generator outputs; ``originals`` holds channel 0 of each clean
        input image as an array of shape (batch, img_size[0], img_size[1], 1).
    """
    denoised = []
    originals = []
    for _ in range(num_batches):
        batch, _labels = next(data_iter)
        # Move axis 1 to the end (presumably NCHW -> NHWC) for the TF model.
        batch = batch.data.numpy().transpose(0, 2, 3, 1)
        noisy = np.zeros([len(batch), img_size[0], img_size[1], 1])
        clean = np.zeros([len(batch), img_size[0], img_size[1], 1])
        for k in range(len(noisy)):
            # A fresh noise standard deviation is drawn for every image.
            noisy[k] = add_gaussian_noise(batch[k],
                                          sd=np.random.uniform(
                                              NOISE_STD_RANGE[0],
                                              NOISE_STD_RANGE[1]))
            clean[k, :, :, 0] = batch[k, :, :, 0]
        # The noisy batch is fed both as the image and as the condition.
        denoised.append(
            sess.run(sargan_model.gen_img,
                     feed_dict={
                         sargan_model.image: noisy,
                         sargan_model.cond: noisy
                     }))
        originals.append(clean)
    return denoised, originals


def transfrom_data(NUM_TEST_PER_EPOCH):
    """Denoise training and validation batches with the model in ``trained_model_path2``.

    ``NUM_ITERATION`` batches are drawn from the training loader and
    ``NUM_TEST_PER_EPOCH`` batches from the validation loader. Each batch is
    corrupted with Gaussian noise and passed through the restored SARGAN
    generator.

    Args:
        NUM_TEST_PER_EPOCH: number of validation batches to process.

    Returns:
        (encoded_data, original_data, val_data, val_original): per-batch lists
        of generator outputs and clean inputs for the training split, followed
        by the same pair for the validation split.
    """
    train_loader, val_loader = get_data(BATCH_SIZE, BATCH_SIZE)
    trainiter = iter(train_loader)
    testiter = iter(val_loader)

    # Same GPU settings as main(): pin to GPU_ID and allocate memory on demand
    # rather than reserving every visible device.
    gpu_options = tf.GPUOptions(allow_growth=True,
                                visible_device_list=str(GPU_ID))
    config = tf.ConfigProto(gpu_options=gpu_options)

    g2 = tf.Graph()
    with g2.as_default():
        with tf.Session(config=config) as sess2:
            sargan_model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
            # NOTE(review): this Saver is immediately replaced by the imported
            # one. It is kept because it adds ops to the graph that
            # import_meta_graph imports into; confirm that the restore below
            # really populates sargan_model's variables.
            sargan_saver = tf.train.Saver()
            sargan_saver = tf.train.import_meta_graph(trained_model_path2 +
                                                      '/sargan_mnist.meta')
            sargan_saver.restore(
                sess2, tf.train.latest_checkpoint(trained_model_path2))

            encoded_data, original_data = _corrupt_and_denoise(
                sess2, sargan_model, trainiter, NUM_ITERATION)
            val_data, val_original = _corrupt_and_denoise(
                sess2, sargan_model, testiter, NUM_TEST_PER_EPOCH)

    return encoded_data, original_data, val_data, val_original
def main(args):
    image_number = 0
    model = SARGAN(img_size, BATCH_SIZE, img_channel=1)
    with tf.variable_scope("d_opt", reuse=tf.AUTO_REUSE):
        d_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.d_loss, var_list=model.d_vars)
    with tf.variable_scope("g_opt", reuse=tf.AUTO_REUSE):
        g_opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(
            model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver(max_to_keep=20)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                visible_device_list=str(GPU_ID))
    config = tf.ConfigProto(gpu_options=gpu_options)

    progress_bar = tqdm(range(MAX_EPOCH), unit="epoch")
    #list of loss values each item is the loss value of one ieteration
    train_d_loss_values = []
    train_g_loss_values = []

    #test_imgs, test_classes = get_data(test_filename)
    #imgs, classes = get_data(train_filename)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        #copies = imgs.astype('float32')
        #test_copies = test_imgs.astype('float32')
        for epoch in progress_bar:
            train_loader, val_loader = get_data_loader(BATCH_SIZE, BATCH_SIZE)
            counter = 0
            epoch_start_time = time.time()
            #shuffle(copies)
            #divide the images into equal sized batches
            #image_batches = np.array(list(chunks(copies, BATCH_SIZE)))
            trainiter = iter(train_loader)
            for i in range(NUM_ITERATION):
                #getting a batch from the training data
                #one_batch_of_imgs = image_batches[i]
                features, labels = next(trainiter)
                features = features.data.numpy().transpose(0, 2, 3, 1)
                #copy the batch
                features_copy = features.copy()
                #corrupt the images
                corrupted_batch = np.array([
                    add_gaussian_noise(image,
                                       sd=np.random.uniform(
                                           NOISE_STD_RANGE[1],
                                           NOISE_STD_RANGE[1]))
                    for image in features_copy
                ])
                for i in range(len(corrupted_batch)):
                    corrupted_batch[i] = gaussian_filter(corrupted_batch[i],
                                                         sigma=1)
                _, m = sess.run([d_opt, model.d_loss],
                                feed_dict={
                                    model.image: features,
                                    model.cond: corrupted_batch
                                })
                _, M = sess.run([g_opt, model.g_loss],
                                feed_dict={
                                    model.image: features,
                                    model.cond: corrupted_batch
                                })
                train_d_loss_values.append(m)
                train_g_loss_values.append(M)
                #print some notifications
                counter += 1
                if counter % 25 == 0:
                    print("\rEpoch [%d], Iteration [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, counter, time.time() - epoch_start_time, m, M))

            # save the trained network
            if epoch % SAVE_EVERY_EPOCH == 0:
                save_path = saver.save(sess,
                                       (trained_model_path + "/sargan_mnist"))
                print("\n\nModel saved in file: %s\n\n" % save_path)
            ''' +\
                                   "%s_model_%s.ckpt" % ( experiment_name, epoch+1))'''

            ##### TESTING FOR CURRUNT EPOCH
            testiter = iter(val_loader)
            NUM_TEST_PER_EPOCH = 1

            #test_batches = np.array(list(chunks(test_copies, BATCH_SIZE)))
            #test_images = test_batches[0]
            sum_psnr = 0
            list_images = []
            for j in range(NUM_TEST_PER_EPOCH):
                features, labels = next(testiter)
                features = features.data.numpy()
                features = features.transpose(0, 2, 3, 1)
                batch_copy = features.copy()
                #corrupt the images
                corrupted_batch = np.array([
                    add_gaussian_noise(image,
                                       sd=np.random.uniform(
                                           NOISE_STD_RANGE[0],
                                           NOISE_STD_RANGE[1]))
                    for image in batch_copy
                ])
                for i in range(len(corrupted_batch)):
                    corrupted_batch[i] = gaussian_filter(corrupted_batch[i],
                                                         sigma=1)

                gen_imgs = sess.run(model.gen_img,
                                    feed_dict={
                                        model.image: features,
                                        model.cond: corrupted_batch
                                    })
                print(features.shape, gen_imgs.shape)
                #if j %17 == 0: # only save 3 images 0, 17, 34
                list_images.append(
                    (features[0], corrupted_batch[0], gen_imgs[0]))
                list_images.append(
                    (features[17], corrupted_batch[17], gen_imgs[17]))
                list_images.append(
                    (features[34], corrupted_batch[34], gen_imgs[34]))
                for i in range(len(gen_imgs)):
                    current_img = features[i]
                    recovered_img = gen_imgs[i]
                    sum_psnr += ski_me.compare_psnr(current_img, recovered_img,
                                                    1)
                #psnr_value = ski_mem.compare_psnr(test_img, gen_img, 1)
                #sum_psnr += psnr_value
            average_psnr = sum_psnr / 50

            epoch_running_time = time.time() - epoch_start_time
            ############### SEND EMAIL ##############
            rows = 1
            cols = 3
            display_mean = np.array([0.485, 0.456, 0.406])
            display_std = np.array([0.229, 0.224, 0.225])
            if epoch % SAVE_EVERY_EPOCH == 0:
                #image = std * image + mean
                imgs_1 = list_images[0]
                imgs_2 = list_images[1]
                imgs_3 = list_images[2]
                imgs_1 = display_std * imgs_1 + display_mean
                imgs_2 = display_std * imgs_2 + display_mean
                imgs_3 = display_std * imgs_3 + display_mean
                fig = plt.figure(figsize=(14, 4))
                ax = fig.add_subplot(rows, cols, 1)
                ax.imshow(imgs_1[0])
                ax.set_title("Original", color='grey')
                ax = fig.add_subplot(rows, cols, 2)
                ax.imshow(imgs_1[1])
                ax.set_title("Corrupted", color='grey')
                ax = fig.add_subplot(rows, cols, 3)
                ax.imshow(imgs_1[2])
                ax.set_title("Recovered", color='grey')
                plt.tight_layout()
                #sample_test_file_1 = os.path.join(output_path, '%s_epoch_%s_batchsize_%s_1.jpg' % (experiment_name, epoch, BATCH_SIZE))
                sample_test_file_1 = os.path.join(
                    output_path, 'image_%d_1.jpg' % image_number)
                plt.savefig(sample_test_file_1, dpi=300)

                fig = plt.figure(figsize=(14, 4))
                ax = fig.add_subplot(rows, cols, 1)
                ax.imshow(imgs_2[0])
                ax.set_title("Original", color='grey')
                ax = fig.add_subplot(rows, cols, 2)
                ax.imshow(imgs_2[1])
                ax.set_title("Corrupted", color='grey')
                ax = fig.add_subplot(rows, cols, 3)
                ax.imshow(imgs_2[2])
                ax.set_title("Recovered", color='grey')
                plt.tight_layout()
                #sample_test_file_2 = os.path.join(output_path, '%s_epoch_%s_batchsize_%s_2.jpg' % (experiment_name, epoch, BATCH_SIZE))
                sample_test_file_2 = os.path.join(
                    output_path, 'image_%d_2.jpg' % image_number)
                plt.savefig(sample_test_file_2, dpi=300)

                fig = plt.figure(figsize=(14, 4))
                ax = fig.add_subplot(rows, cols, 1)
                ax.imshow(imgs_3[0])
                ax.set_title("Original", color='grey')
                ax = fig.add_subplot(rows, cols, 2)
                ax.imshow(imgs_3[1])
                ax.set_title("Corrupted", color='grey')
                ax = fig.add_subplot(rows, cols, 3)
                ax.imshow(imgs_3[2])
                ax.set_title("Recovered", color='grey')
                plt.tight_layout()
                #sample_test_file_3 = os.path.join(output_path, '%s_epoch_%s_batchsize_%s_3.jpg' % (experiment_name, epoch, BATCH_SIZE))
                sample_test_file_3 = os.path.join(
                    output_path, 'image_%d_3.jpg' % image_number)
                image_number += 1
                plt.savefig(sample_test_file_3, dpi=300)
            '''attachments = [sample_test_file_1, sample_test_file_2, sample_test_file_3]

            email_alert_subject = "Epoch %s %s SARGAN" % (epoch+1, experiment_name.upper())
            
            email_alert_text = """
            Epoch [{}/{}]
            EXPERIMENT PARAMETERS:
                        Learning rate: {}
                        Batch size: {}
                        Max epoch number: {}
                        Training iterations each epoch: {}
                        noise standard deviation: {}
                        Running time: {:.2f} [s]
            AVERAGE PSNR VALUE ON 50 TEST IMAGES: {}
            """.format(epoch + 1, MAX_EPOCH, LEARNING_RATE, BATCH_SIZE, MAX_EPOCH, 
                      NUM_ITERATION, ((NOISE_STD_RANGE[0] + NOISE_STD_RANGE[1])/2), epoch_running_time, average_psnr)
            if epoch % SAVE_EVERY_EPOCH == 0: 
                send_images_via_email(email_alert_subject,
                 email_alert_text,
                 attachments,
                 sender_email="*****@*****.**", 
                 recipient_emails=["*****@*****.**"])'''
            plt.close("all")