Example 1
import argparse

import numpy as np
import pandas as pd
import tensorflow as tf

import model1
import model2
# load_training_data, get_training_batch, and cross_entropy are assumed to be
# project-local helpers (a sketch of cross_entropy follows this example).


def main():
    prince = True  # Selects the TF >= 1.0 variable initializer below.
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--z_dim', type=int, default=10,
                       help='Noise dimension')
    
    parser.add_argument('--t_dim', type=int, default=64,
                       help='Text feature dimension')
    
    parser.add_argument('--batch_size', type=int, default=64,
                       help='Batch Size')
    
    parser.add_argument('--image_size', type=int, default=64,
                       help='Image size a (images are a x a)')
    
    parser.add_argument('--gf_dim', type=int, default=16,
                       help='Number of conv in the first layer gen.')
    
    parser.add_argument('--df_dim', type=int, default=16,
                       help='Number of conv in the first layer discr.')
    
    parser.add_argument('--gfc_dim', type=int, default=1024,
                       help='Dimension of gen units for the fully connected layer')
    
    parser.add_argument('--caption_vector_length', type=int, default=556,
                       help='Caption Vector Length')
    
    parser.add_argument('--data_dir', type=str, default="Data",
                       help='Data Directory')
    
    parser.add_argument('--beta1', type=float, default=0.5,
                       help='Momentum for Adam Update')

    parser.add_argument('--data_set', type=str, default="flowers",
                       help='Data set: MS-COCO, flowers')

    args = parser.parse_args()
    
    save_epoch = [5, 10, 20, 50]
    
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.54)
    
    cols = ['epoch_' + str(i) for i in save_epoch]
    df_loss = pd.DataFrame(index=cols, columns=cols)
    
    # Cross-play evaluation: for every pair of saved epochs, restore the
    # "Feat" checkpoint (generator g1 / discriminator d1) and the "Unsupervs"
    # checkpoint (g2 / d2), score each generator's samples with both
    # discriminators, and tally per-sample wins.
    for i in range(len(save_epoch)):
        for j in range(len(save_epoch)):
            modelFile1 = "./Data/ModelEval/Feat_checkpoint_epoch_%d.ckpt" % (save_epoch[i])
            modelFile2 = "./Data/ModelEval/Unsupervs_checkpoint_epoch_%d.ckpt" % (save_epoch[j])
            
            model_options = {
                    'z_dim' : args.z_dim,
                    't_dim' : args.t_dim,
                    'batch_size' : args.batch_size,
                    'image_size' : args.image_size,
                    'gf_dim' : args.gf_dim,
                    'df_dim' : args.df_dim,
                    'gfc_dim' : args.gfc_dim,
                    'caption_vector_length' : args.caption_vector_length
                }
    
            gan1 = model1.GAN(model_options)
            input_tensors1, _, _, outputs1 = gan1.build_model(args.beta1, .9, 1e-4)  # (beta1, beta2, learning_rate)
            sess1 = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
            
            if prince:
                sess1.run(tf.global_variables_initializer())
            else:
                sess1.run(tf.initialize_all_variables())
            saver = tf.train.Saver()
            # Restore the first model:
            saver.restore(sess1, modelFile1)        
            loaded_data = load_training_data(args.data_dir, args.data_set)
            batch_no = 0
            _, _, caption_vectors, z_noise, _ = get_training_batch(batch_no, args.batch_size, 
                args.image_size, args.z_dim, args.caption_vector_length, 'train', args.data_dir, args.data_set, loaded_data)        
            # Get output image from first model
            g1_img3 = sess1.run(outputs1['img3'],
                feed_dict = {
                    input_tensors1['t_real_caption'] : caption_vectors,
                    input_tensors1['t_z'] : z_noise,
                    input_tensors1['noise_indicator'] : 0,
                    input_tensors1['noise_gen'] : 0
                })
            g1_d1_p_3_gen_img_logit, g1_d1_p_3_gen_txt_logit = sess1.run(
                [outputs1['output_p_3_gen_img_logit'],
                 outputs1['output_p_3_gen_txt_logit']],
                feed_dict={
                    input_tensors1['t_real_caption']: caption_vectors,
                    input_tensors1['t_z']: z_noise,
                    input_tensors1['noise_indicator']: 0,
                    input_tensors1['gen_image1']: g1_img3,
                    input_tensors1['noise_disc']: 0,
                    input_tensors1['noise_gen']: 0
                })
            #print('g1_d1_p_3_gen_img_logit:')
            #print(g1_d1_p_3_gen_img_logit)
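            # Generator g1's adversarial loss under discriminator d1:
            # sigmoid cross-entropy of the image and text logits against
            # all-ones ("real") labels, kept per sample for the comparison below.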
            g1_d1_loss = cross_entropy(g1_d1_p_3_gen_img_logit, np.ones((args.batch_size, 1))) + cross_entropy(g1_d1_p_3_gen_txt_logit, np.ones((args.batch_size, 1)))            
            tf.reset_default_graph()
            sess1.close()            
            # Create second model
            model_options = {
                    'z_dim' : args.z_dim,
                    't_dim' : args.t_dim,
                    'batch_size' : args.batch_size,
                    'image_size' : args.image_size,
                    'gf_dim' : args.gf_dim,
                    'df_dim' : args.df_dim,
                    'gfc_dim' : args.gfc_dim,
                    'caption_vector_length' : args.caption_vector_length
                }            
            gan2 = model2.GAN(model_options)
            input_tensors2, _, _, outputs2 = gan2.build_model(args.beta1, .9, 1e-4)        
            sess2 = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))        
            if prince:
                sess2.run(tf.global_variables_initializer())
            else:
                sess2.run(tf.initialize_all_variables())
            saver2 = tf.train.Saver()            
            saver2.restore(sess2, modelFile2)            
            # Get logits from the second model
            g1_d2_p_3_gen_img_logit, g1_d2_p_3_gen_txt_logit = sess2.run(
                [outputs2['output_p_3_gen_img_logit'],
                 outputs2['output_p_3_gen_txt_logit']],
                feed_dict={
                    input_tensors2['t_real_caption']: caption_vectors,
                    input_tensors2['t_z']: z_noise,
                    input_tensors2['noise_indicator']: 0,
                    input_tensors2['gen_image1']: g1_img3,
                    input_tensors2['noise_disc']: 0,
                    input_tensors2['noise_gen']: 0
                })
            g1_d2_loss = cross_entropy(g1_d2_p_3_gen_img_logit, np.ones((args.batch_size, 1))) + cross_entropy(g1_d2_p_3_gen_txt_logit, np.ones((args.batch_size, 1)))            
            # Get output image from second model
            g2_img3 = sess2.run(outputs2['img3'],
                feed_dict = {
                    input_tensors2['t_real_caption'] : caption_vectors,
                    input_tensors2['t_z'] : z_noise,
                    input_tensors2['noise_indicator'] : 0,
                    input_tensors2['noise_gen'] : 0
                })            
            # Get logits from the second model
            g2_d2_p_3_gen_img_logit, g2_d2_p_3_gen_txt_logit = sess2.run(
                [outputs2['output_p_3_gen_img_logit'],
                 outputs2['output_p_3_gen_txt_logit']],
                feed_dict={
                    input_tensors2['t_real_caption']: caption_vectors,
                    input_tensors2['t_z']: z_noise,
                    input_tensors2['noise_indicator']: 0,
                    input_tensors2['gen_image1']: g2_img3,
                    input_tensors2['noise_disc']: 0,
                    input_tensors2['noise_gen']: 0
                })
            g2_d2_loss = cross_entropy(g2_d2_p_3_gen_img_logit, np.ones((args.batch_size, 1))) + cross_entropy(g2_d2_p_3_gen_txt_logit, np.ones((args.batch_size, 1)))            
            tf.reset_default_graph()
            sess2.close()            
            model_options = {
                    'z_dim' : args.z_dim,
                    't_dim' : args.t_dim,
                    'batch_size' : args.batch_size,
                    'image_size' : args.image_size,
                    'gf_dim' : args.gf_dim,
                    'df_dim' : args.df_dim,
                    'gfc_dim' : args.gfc_dim,
                    'caption_vector_length' : args.caption_vector_length
                }            
            gan1 = model1.GAN(model_options)
            input_tensors1, _, _, outputs1 = gan1.build_model(args.beta1, .9, 1e-4)            
            sess1 = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))            
            if prince:
                sess1.run(tf.global_variables_initializer())
            else:
                sess1.run(tf.initialize_all_variables())
            saver = tf.train.Saver()          
            saver.restore(sess1, modelFile1)
            # Get logits from the first model
            g2_d1_p_3_gen_img_logit, g2_d1_p_3_gen_txt_logit = sess1.run(
                [outputs1['output_p_3_gen_img_logit'],
                 outputs1['output_p_3_gen_txt_logit']],
                feed_dict={
                    input_tensors1['t_real_caption']: caption_vectors,
                    input_tensors1['t_z']: z_noise,
                    input_tensors1['noise_indicator']: 0,
                    input_tensors1['gen_image1']: g2_img3,
                    input_tensors1['noise_disc']: 0,
                    input_tensors1['noise_gen']: 0
                })
            #print('g1_d1_p_3_gen_img_logit:')
            #print(g1_d1_p_3_gen_img_logit)
            #print('g2_d1_p_3_gen_img_logit:')
            #print(g2_d1_p_3_gen_img_logit)
            tf.reset_default_graph()
            sess1.close()
            g2_d1_loss = cross_entropy(g2_d1_p_3_gen_img_logit, np.ones((args.batch_size, 1))) + cross_entropy(g2_d1_p_3_gen_txt_logit, np.ones((args.batch_size, 1)))
            g1_wins_on_d2 = 0
            g2_wins_on_d1 = 0
            for idx in range(g2_d1_loss.shape[0]):
                # Compare loss on disc 1
                if g1_d1_loss[idx][0] > g2_d1_loss[idx][0]:
                    g2_wins_on_d1 += 1
                    print(g2_d1_loss[idx][0],'<', g1_d1_loss[idx][0], 'g2 wins on d1')
                else:
                    print(g2_d1_loss[idx][0],'>', g1_d1_loss[idx][0], 'g1 wins on d1')                
                
                # Compare loss on disc 2
                if g1_d2_loss[idx][0] < g2_d2_loss[idx][0]:
                    g1_wins_on_d2 += 1
                    print(g1_d2_loss[idx][0],'<', g2_d2_loss[idx][0], 'g1 wins on d2')
                else:
                    print(g1_d2_loss[idx][0],'>', g2_d2_loss[idx][0], 'g2 wins on d2')
                
            df_loss.loc[cols[i], cols[j]] = str(g2_wins_on_d1) + '/' + str(g1_wins_on_d2)
    df_loss.to_csv('Feat_Uns.csv')
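
Note: cross_entropy is called above but not defined in this listing. A minimal
NumPy sketch consistent with how it is used (element-wise sigmoid
cross-entropy on logits of shape (batch_size, 1), so results can be indexed as
loss[idx][0]) might look as follows; the project's actual helper may differ.

import numpy as np

def cross_entropy(logits, labels):
    # Numerically stable element-wise sigmoid cross-entropy with logits,
    # equivalent to tf.nn.sigmoid_cross_entropy_with_logits:
    # max(x, 0) - x * z + log(1 + exp(-|x|)).
    return (np.maximum(logits, 0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))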
Example 2
import argparse
import os
import sys
import time

import numpy as np
import scipy.misc
import tensorflow as tf

import model
# load_flower_dict, load_flower_captions, load_training_data,
# get_combination_batch, and super_blur are assumed to be project-local
# helpers (a sketch of super_blur follows this example).

prince = True  # Assumed flag, matching Example 1; it is undefined in this listing.


def main():
    # Tee everything written to stdout into a timestamped log file as well.
    class Logger(object):
        def __init__(self, filename="last_run_output.txt"):
            self.terminal = sys.stdout
            self.log = open(filename, "a")

        def write(self, message):
            self.terminal.write(message)
            self.log.write(message)
            self.flush()

        def flush(self):
            self.log.flush()

    sys.stdout = Logger("logs/" + str(os.path.basename(sys.argv[0])) +
                        str(time.time()) + ".txt")

    parser = argparse.ArgumentParser()
    parser.add_argument('--z_dim',
                        type=int,
                        default=10,
                        help='Noise dimension')

    parser.add_argument(
        '--t_dim',
        type=int,
        default=64,  #1024#LEE,#default=256,
        help='Text feature dimension')

    parser.add_argument(
        '--batch_size',
        type=int,
        default=64,  #LEECHANGE default=64,
        help='Batch Size')

    parser.add_argument('--image_size',
                        type=int,
                        default=64,
                        help='Image size a (images are a x a)')

    parser.add_argument(
        '--gf_dim',
        type=int,
        default=16,  #30
        help='Number of conv in the first layer gen.')

    parser.add_argument(
        '--df_dim',
        type=int,
        default=16,  #128,#12
        help='Number of conv in the first layer discr.')

    parser.add_argument(
        '--gfc_dim',
        type=int,
        default=1024,
        help='Dimension of gen units for the fully connected layer')

    parser.add_argument(
        '--caption_vector_length',
        type=int,
        default=556,  #4096 - zdim (30)#2400Lee
        help='Caption Vector Length')

    parser.add_argument('--data_dir',
                        type=str,
                        default="Data",
                        help='Data Directory')

    parser.add_argument(
        '--learning_rate',
        type=float,
        default=1e-4,  #1e-4 or 1e-5LEECHANGE default=0.0002,
        help='Learning Rate')

    parser.add_argument(
        '--beta1',
        type=float,
        default=.5,  #LEECHANGE default=0.5,
        help='Momentum for Adam Update')

    parser.add_argument('--epochs',
                        type=int,
                        default=51,
                        help='Max number of epochs')

    parser.add_argument(
        '--save_every',
        type=int,
        default=30,
        help='Save Model/Samples every x iterations over batches')

    parser.add_argument('--resume_model',
                        type=str,
                        default=None,
                        help='Pre-Trained Model Path, to resume from')

    parser.add_argument('--data_set',
                        type=str,
                        default="flowers",
                        help='Data set: MS-COCO, flowers')

    # nargs='+' parses a list of ints; argparse's type=list would split the
    # argument string into single characters.
    parser.add_argument('--save_epoch',
                        type=int,
                        nargs='+',
                        default=[5, 10, 20, 50],
                        help='Epochs at which to save the model')

    args = parser.parse_args()

    word_dict = load_flower_dict("dictionary.txt")
    flowerdataDir = 'Data/'
    augDataDir = 'Data/augment/'  # relative path, matching flowerdataDir

    flower_loaded_data = load_training_data(flowerdataDir, 'flowers')
    flower_image_captions = load_flower_captions("Data")
    aug_loaded_data = load_training_data(augDataDir, 'augment')
    aug_image_captions = aug_loaded_data['captions']

    model_options = {
        'z_dim': args.z_dim,
        't_dim': args.t_dim,
        'batch_size': args.batch_size,
        'image_size': args.image_size,
        'gf_dim': args.gf_dim,
        'df_dim': args.df_dim,
        'gfc_dim': args.gfc_dim,
        'caption_vector_length': args.caption_vector_length
    }
    beta2 = .9

    gan = model.GAN(model_options)
    input_tensors, variables, loss, outputs = gan.build_model(
        args.beta1, beta2, args.learning_rate)

    g_optim = gan.g_optim
    d_optim = gan.d_optim
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.54)

    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    #sess = tf.InteractiveSession()
    if prince:
        sess.run(tf.global_variables_initializer())
        '''
        sess.run(tf.global_variables_initializer(),feed_dict = {
            input_tensors['t_real_image'] : np.zeros((64,64,64,3)),
            input_tensors['t_real_caption'] : np.zeros((64, 100)),
            input_tensors['t_z'] : np.zeros((64,10)),
            input_tensors['noise_gen'] : 0,
            input_tensors['noise_disc'] : 0})
        '''
    else:
        sess.run(tf.initialize_all_variables())

    saver = tf.train.Saver()
    if args.resume_model:
        saver.restore(sess, args.resume_model)

    loaded_data = load_training_data(args.data_dir, args.data_set)
    init = -1
    skip_n = 0
    d_avg_full, d_avg_mid, d_avg_sml = 0, 0, 0
    lb = 0  #-.3
    disc_break = -.3
    t1 = 0
    t2 = 0
    d_loss_noise = g_loss_noise = 0
    d1_loss = d2_loss = d3_loss = 3
    num_disc_steps = 0
    time_size = 10368  #df dim 30 : 19440
    real_time = np.zeros(time_size)
    fake_time = np.zeros(time_size)
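
    # Instance noise for the generator/discriminator and the learning rate
    # both decay as 1/sqrt(epoch) inside the loop below.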
    for i in range(args.epochs):
        noise_gen = .1 / np.sqrt(i + 1)
        noise_disc = .1 / np.sqrt(i + 1)
        t = time.time()
        img_idx = 0
        batch_no = 0
        caption_vectors = 0
        mul1 = 1
        mul2 = 1
        mul3 = 1
        flower_dataset_batch_size = 64
        while batch_no * flower_dataset_batch_size < loaded_data['data_length']:
            trans_mult = 5 / np.sqrt(25 + i)
            lr = 1e-4 / np.sqrt(1 + i)
            # real_images, wrong_images, caption_vectors, z_noise, image_files = get_training_batch(batch_no, args.batch_size,
            #     args.image_size, args.z_dim, args.caption_vector_length, 'train', args.data_dir, args.data_set, loaded_data)

            ## Zizhuo REN ###########
            annotations, caption_vectors, real_images = get_combination_batch(
                64, flower_dataset_batch_size,
                args.batch_size - flower_dataset_batch_size, batch_no,
                flower_loaded_data, flower_image_captions, aug_loaded_data,
                aug_image_captions, word_dict)
            wrong_images = np.concatenate(
                (real_images[30:, :, :, :], real_images[:30, :, :, :]), 0)
            #########################
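            # Uniform noise input for the generator; note the width 10 is
            # hard-coded and should match args.z_dim.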
            z_noise = np.random.rand(args.batch_size, 10)
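            # Mismatched-caption negatives: roll the caption batch by two
            # random offsets so images are also paired with captions that do
            # not describe them.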
            c_off1 = np.random.randint(2, 10)
            c_off2 = np.random.randint(11, 20)
            caption_wrong1 = np.concatenate(
                (caption_vectors[c_off1:], caption_vectors[:c_off1]))
            caption_wrong2 = np.concatenate(
                (caption_vectors[c_off2:], caption_vectors[:c_off2]))
            t1 += time.time() - t
            t = time.time()
            # Generator update; whether to also step the discriminator is
            # decided below from the current discriminator losses.
            lambdaDAE = 10
            horrible_images = super_blur(real_images)
            rt, ft, _, g_loss, g1_loss, g2_loss, g3_loss, \
                img1, img2, img3 = sess.run(
                [gan.real_acts, gan.fake_acts, g_optim, loss['g_loss'],
                 loss['g1_loss'], loss['g2_loss'], loss['g3_loss'],
                 outputs['img1'], outputs['img2'], outputs['img3']],
                feed_dict = {
                    input_tensors['t_real_image'] : real_images,
                    input_tensors['t_real_caption'] : caption_vectors,
                    input_tensors['cap_wrong1'] : caption_wrong1,
                    input_tensors['cap_wrong2'] : caption_wrong2,
                    input_tensors['t_z'] : z_noise,
                    input_tensors['l2reg']: 0,
                    input_tensors['LambdaDAE']: lambdaDAE,
                    input_tensors['noise_indicator'] : 0,
                    input_tensors['noise_gen'] : noise_gen,
                    input_tensors['noise_disc'] : noise_disc,
                    input_tensors['mul1'] : mul1,
                    input_tensors['mul2'] : mul2,
                    input_tensors['mul3'] : mul3,
                    gan.past_reals: real_time,
                    gan.past_fakes: fake_time,
                    gan.trans_mult : trans_mult,
                    input_tensors['lr'] : lr
                })

            # Exponential moving average (decay 0.9) of the discriminator's
            # real/fake feature activations, fed back as gan.past_reals /
            # gan.past_fakes on subsequent steps.
            real_time = real_time * .9 + rt * .1
            fake_time = fake_time * .9 + ft * .1

            d1_loss, d2_loss, d3_loss = sess.run(
                [
                    loss['d1_loss_gen'], loss['d2_loss_gen'],
                    loss['d3_loss_gen']
                ],
                feed_dict={
                    input_tensors['t_real_image']: real_images,
                    input_tensors['t_real_caption']: caption_vectors,
                    input_tensors['cap_wrong1']: caption_wrong1,
                    input_tensors['cap_wrong2']: caption_wrong2,
                    input_tensors['LambdaDAE']: lambdaDAE,
                    input_tensors['t_z']: z_noise,
                    input_tensors['noise_indicator']: 0,
                    input_tensors['noise_gen']: noise_gen,
                    input_tensors['mul1']: mul1,
                    input_tensors['mul2']: mul2,
                    input_tensors['mul3']: mul3,
                    input_tensors['gen_image1']: img3,
                    input_tensors['gen_image2']: img2,
                    input_tensors['gen_image4']: img1,
                    input_tensors['noise_disc']: noise_disc
                })
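            # Adaptive schedule: only take a discriminator step while the
            # discriminator is doing poorly (all generator-side d losses above
            # 1.2); otherwise shrink its weights (gan.l2_disc) and give the
            # generator an extra update.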
            if np.min([d1_loss, d2_loss, d3_loss]) > 1.2 or init < 0:
                if init < 0:
                    init += 1
                num_disc_steps += 1
                print('running real disc')
                sess.run(
                    [d_optim],
                    feed_dict={
                        input_tensors['t_real_image']: real_images,
                        input_tensors['t_wrong_image']: wrong_images,
                        input_tensors['t_horrible_image']: horrible_images,
                        input_tensors['t_real_caption']: caption_vectors,
                        input_tensors['cap_wrong1']: caption_wrong1,
                        input_tensors['cap_wrong2']: caption_wrong2,
                        input_tensors['t_z']: z_noise,
                        input_tensors['l2reg']: 0,
                        input_tensors['LambdaDAE']: lambdaDAE,
                        input_tensors['noise_indicator']: 0,
                        input_tensors['noise_gen']: noise_gen,
                        input_tensors['noise_disc']: noise_disc,
                        input_tensors['mul1']: mul1,
                        input_tensors['mul2']: mul2,
                        input_tensors['mul3']: mul3,
                        gan.trans_mult: trans_mult,
                        input_tensors['gen_image1']: img3,
                        input_tensors['gen_image2']: img2,
                        input_tensors['gen_image4']: img1,
                        input_tensors['lr']: lr
                    })
            else:
                sess.run(gan.l2_disc)
                rt, ft, _, g_loss, g1_loss, g2_loss, g3_loss, \
                    img1, img2, img3 = sess.run(
                    [gan.real_acts, gan.fake_acts, g_optim, loss['g_loss'],
                     loss['g1_loss'], loss['g2_loss'], loss['g3_loss'],
                     outputs['img1'], outputs['img2'], outputs['img3']],
                    feed_dict = {
                        input_tensors['t_real_image'] : real_images,
                        input_tensors['t_real_caption'] : caption_vectors,
                        input_tensors['cap_wrong1'] : caption_wrong1,
                        input_tensors['cap_wrong2'] : caption_wrong2,
                        input_tensors['t_z'] : z_noise,
                        input_tensors['l2reg']: 0,
                        input_tensors['LambdaDAE']: lambdaDAE,
                        input_tensors['noise_indicator'] : 0,
                        input_tensors['noise_gen'] : noise_gen,
                        input_tensors['noise_disc'] : noise_disc,
                        input_tensors['mul1'] : mul1,
                        input_tensors['mul2'] : mul2,
                        input_tensors['mul3'] : mul3,
                        gan.past_reals: real_time,
                        gan.past_fakes: fake_time,
                        gan.trans_mult : trans_mult,
                        input_tensors['lr'] : lr
                    })
                real_time = real_time * .9 + rt * .1
                fake_time = fake_time * .9 + ft * .1

            # NOISE PHASE: repeat the generator/discriminator updates with
            # low-magnitude random caption vectors (noise_indicator = 1) so
            # the model also learns to handle junk conditioning.
            rt, ft, _, g_loss_noise, g1_loss_noise, g2_loss_noise, g3_loss_noise, \
                img1_noise, img2_noise, img3_noise = sess.run(
                [gan.real_acts, gan.fake_acts, g_optim, loss['g_loss'],
                 loss['g1_loss_noise'], loss['g2_loss_noise'], loss['g3_loss_noise'],
                 outputs['img1'], outputs['img2'], outputs['img3']],
                feed_dict = {
                    input_tensors['t_real_image'] : real_images,
                    input_tensors['t_real_caption'] : np.random.rand(
                                args.batch_size, args.caption_vector_length)*.2,
                    input_tensors['cap_wrong1'] : np.random.rand(
                                args.batch_size, args.caption_vector_length)*.2,
                    input_tensors['cap_wrong2'] : np.random.rand(
                                args.batch_size, args.caption_vector_length)*.2,
                    input_tensors['t_z'] : z_noise,
                    input_tensors['l2reg']: 0,
                    input_tensors['LambdaDAE']: lambdaDAE,
                    input_tensors['noise_indicator'] : 1,
                    input_tensors['noise_gen'] : noise_gen,
                    input_tensors['noise_disc'] : noise_disc,
                    input_tensors['mul1'] : mul1,
                    input_tensors['mul2'] : mul2,
                    input_tensors['mul3'] : mul3,
                    gan.past_reals: real_time,
                    gan.past_fakes: fake_time,
                    gan.trans_mult : trans_mult,
                    input_tensors['lr'] : lr
                })
            real_time = real_time * .9 + rt * .1
            fake_time = fake_time * .9 + ft * .1

            d1_loss_noise, d2_loss_noise, d3_loss_noise = sess.run(
                [loss['d1_loss_gen_noise'], loss['d2_loss_gen_noise'],
                 loss['d3_loss_gen_noise']],
                feed_dict={
                    input_tensors['t_real_image']: real_images,
                    input_tensors['t_real_caption']: np.random.rand(
                        args.batch_size, args.caption_vector_length) * .2,
                    input_tensors['cap_wrong1']: np.random.rand(
                        args.batch_size, args.caption_vector_length) * .2,
                    input_tensors['cap_wrong2']: np.random.rand(
                        args.batch_size, args.caption_vector_length) * .2,
                    input_tensors['LambdaDAE']: lambdaDAE,
                    input_tensors['t_z']: z_noise,
                    input_tensors['noise_indicator']: 1,
                    input_tensors['noise_gen']: noise_gen,
                    input_tensors['mul1']: mul1,
                    input_tensors['mul2']: mul2,
                    input_tensors['mul3']: mul3,
                    input_tensors['gen_image1']: img3_noise,
                    input_tensors['gen_image2']: img2_noise,
                    input_tensors['gen_image4']: img1_noise,
                    input_tensors['noise_disc']: noise_disc
                })
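            # Same adaptive rule for the noise phase, with a lower threshold (0.6).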
            if np.min([d1_loss_noise, d2_loss_noise, d3_loss_noise]) > .6:
                num_disc_steps += 1
                if init < 0:
                    init += 1
                print('running noise disc')
                horrible_images = super_blur(real_images)
                sess.run(
                    [d_optim],
                    feed_dict={
                        input_tensors['t_real_image']: real_images,
                        input_tensors['t_wrong_image']: wrong_images,
                        input_tensors['t_horrible_image']: horrible_images,
                        input_tensors['t_real_caption']: np.random.rand(
                            args.batch_size, args.caption_vector_length) * .2,
                        input_tensors['cap_wrong1']: np.random.rand(
                            args.batch_size, args.caption_vector_length) * .2,
                        input_tensors['cap_wrong2']: np.random.rand(
                            args.batch_size, args.caption_vector_length) * .2,
                        input_tensors['t_z']: z_noise,
                        input_tensors['l2reg']: 0,
                        input_tensors['LambdaDAE']: lambdaDAE,
                        input_tensors['noise_indicator']: 1,
                        input_tensors['noise_gen']: noise_gen,
                        input_tensors['noise_disc']: noise_disc,
                        input_tensors['mul1']: mul1,
                        input_tensors['mul2']: mul2,
                        input_tensors['mul3']: mul3,
                        gan.past_reals: real_time,
                        gan.past_fakes: fake_time,
                        gan.trans_mult: trans_mult,
                        input_tensors['gen_image1']: img3_noise,
                        input_tensors['gen_image2']: img2_noise,
                        input_tensors['gen_image4']: img1_noise,
                        input_tensors['lr']: lr
                    })
            else:
                sess.run(gan.l2_disc)
                rt, ft, _, g_loss_noise, g1_loss_noise, g2_loss_noise, g3_loss_noise, \
                    img1_noise, img2_noise, img3_noise = sess.run(
                    [gan.real_acts, gan.fake_acts, g_optim, loss['g_loss'],
                     loss['g1_loss_noise'], loss['g2_loss_noise'], loss['g3_loss_noise'],
                     outputs['img1'], outputs['img2'], outputs['img3']],
                    feed_dict = {
                        input_tensors['t_real_image'] : real_images,
                        input_tensors['t_real_caption'] : np.random.rand(
                                    args.batch_size, args.caption_vector_length)*.2,
                        input_tensors['cap_wrong1'] : np.random.rand(
                                    args.batch_size, args.caption_vector_length)*.2,
                        input_tensors['cap_wrong2'] : np.random.rand(
                                    args.batch_size, args.caption_vector_length)*.2,
                        input_tensors['t_z'] : z_noise,
                        input_tensors['l2reg']: 0,
                        input_tensors['LambdaDAE']: lambdaDAE,
                        input_tensors['noise_indicator'] : 1,
                        input_tensors['noise_gen'] : noise_gen,
                        input_tensors['noise_disc'] : noise_disc,
                        gan.past_reals: real_time,
                        gan.past_fakes: fake_time,
                        input_tensors['mul1'] : mul1,
                        input_tensors['mul2'] : mul2,
                        input_tensors['mul3'] : mul3,
                        gan.trans_mult : trans_mult,
                        input_tensors['lr'] : lr
                    })

                real_time = real_time * .9 + rt * .1
                fake_time = fake_time * .9 + ft * .1

            t2 += time.time() - t
            t = time.time()

            # Every other batch, fetch sample images and the intermediate
            # transformation outputs for visual inspection.
            if batch_no % 2 == 0:
                img1, img2, img3, trans1, trans2, trans3, trans21, trans22, trans23 = sess.run(
                    [
                        outputs['img1'], outputs['img2'], outputs['img3'],
                        outputs['trans1'], outputs['trans2'],
                        outputs['trans3'], outputs['trans21'],
                        outputs['trans22'], outputs['trans23']
                    ],
                    feed_dict={
                        input_tensors['t_real_image']: real_images,
                        input_tensors['t_real_caption']: caption_vectors,
                        input_tensors['t_z']: z_noise,
                        input_tensors['noise_indicator']: 0,
                        input_tensors['noise_gen']: 0,
                        input_tensors['noise_disc']: 0,
                        input_tensors['gen_image1']: img3,
                        input_tensors['gen_image2']: img2,
                        input_tensors['gen_image4']: img1
                    })
                idx = np.random.randint(20, 60)
                image1 = img1[idx, :, :, :]
                image2 = img2[idx, :, :, :]
                image3 = img3[idx, :, :, :]
                trs1 = trans1[idx, :, :, :]
                trs2 = trans2[idx, :, :, :]
                trs3 = trans3[idx, :, :, :]
                trs21 = trans21[idx, :, :, :]
                trs22 = trans22[idx, :, :, :]
                trs23 = trans23[idx, :, :, :]
                real_full = real_images[idx, :, :, :]
                horrible_full = horrible_images[idx, :, :, :]
                folder = 'images_vector5/'
                ann_write = annotations[idx]
                ann_write = ann_write.replace('.', '')
                if len(ann_write) > 60:
                    ann_write = ann_write[:60] + '_'
                if 0:
                    scipy.misc.imsave(
                        folder + str(i) + '_' + str(batch_no) + ann_write +
                        'stage1.jpg', image1)
                    scipy.misc.imsave(
                        folder + str(i) + '_' + str(batch_no) + ann_write +
                        'stage2.jpg', image2)
                    scipy.misc.imsave(
                        folder + str(i) + '_' + str(batch_no) + '_' +
                        ann_write + 'gen.jpg', image3)
                    scipy.misc.imsave(
                        folder + str(i) + '_' + str(batch_no) + '_' +
                        ann_write + 'real.jpg', real_full)
                img_idx += 1

            print(num_disc_steps, ' disc steps so far')
            print('epoch:', i, 'batch:', batch_no)
            print('d_loss', d1_loss, d2_loss, d3_loss)
            print('g_loss', g1_loss, g2_loss, g3_loss)
            print('d_loss_noise', d1_loss_noise, d2_loss_noise, d3_loss_noise)
            print('g_loss_noise', g1_loss_noise, g2_loss_noise, g3_loss_noise)

            batch_no += 1
            if 0:  #(batch_no % args.save_every) == 0:
                #Lee commented the following line out because it crashed. No idea what it was trying to do.
                #save_for_vis(args.data_dir, real_images, gen, image_files)
                save_path = saver.save(
                    sess,
                    "Data/Models/latest_model_vector5_{}_temp.ckpt".format(
                        args.data_set))
            if (batch_no == 1 and i in args.save_epoch):
                save_path = saver.save(
                    sess,
                    "Data/ModelEval/Feat_checkpoint_epoch_{}.ckpt".format(i))