Esempio n. 1
0
    def gan_train(max_it, it_offset):
        """Run ``max_it`` GAN training iterations starting at ``it_offset``.

        Every iteration fetches one batch from ``data_pool`` and runs the
        discriminator step together with a generator step: ``g_step2`` while
        the iteration index is at most 30000, ``g_step`` afterwards.  The
        merged summary is written every 10th iteration and one image grid per
        entry of ``fake_set`` is saved every 1000th iteration.

        Args:
            max_it: number of iterations to run.
            it_offset: index of the first iteration.
        """
        print("GAN iteration: " + str(max_it))
        for step in range(it_offset, it_offset + max_it):
            # The labels come with the batch but are not used by this loop.
            real_ipt, _labels = data_pool.batch(['img', 'label'])
            feed = {real: real_ipt}

            # g_step2 was annotated "ori g loss" by the original author.
            generator_op = g_step if step > 30000 else g_step2
            _, _ = sess.run([d_step, generator_op], feed_dict=feed)

            if step % 10 == 0:
                writer.add_summary(sess.run(merged, feed_dict=feed), step)

            if step % 1000 != 0:
                continue

            for idx, fake_op in enumerate(fake_set):
                # Affine map x -> 2x - 1 (presumably [0, 1] -> [-1, 1];
                # confirm against the generator's output range).
                grid = sess.run(fake_op) * 2. - 1.
                save_dir = dir + "/sample_imgs"
                utils.mkdir(save_dir + '/')
                file_name = 'sample-%d-%d.jpg' % (idx, step)
                my_utils.saveSampleImgs(imgs=grid,
                                        full_path=save_dir + "/" + file_name,
                                        row=8,
                                        column=8)
Esempio n. 2
0
 def gan_train(max_it, it_offset):
     """Train the GAN for ``max_it`` iterations, counting from ``it_offset``.

     Per iteration: fetch a batch from ``data_pool`` and run ``d_step`` and
     ``g_step`` in one session call.  Every 10th iteration the merged summary
     is written; every 1000th iteration one 8x8 image grid is saved for each
     entry of ``fake_set``.

     Args:
         max_it: number of iterations to run.
         it_offset: index of the first iteration.
     """
     print("GAN iteration: " + str(max_it))
     first = it_offset
     last = it_offset + max_it
     for step in range(first, last):
         # Labels are fetched alongside the images but are unused here.
         real_ipt, _labels = data_pool.batch(['img', 'label'])
         batch_feed = {real: real_ipt}
         _, _ = sess.run([d_step, g_step], feed_dict=batch_feed)

         if step % 10 == 0:
             summary_proto = sess.run(merged, feed_dict=batch_feed)
             writer.add_summary(summary_proto, step)

         if step % 1000 == 0:
             # Imported lazily, only when samples are actually written.
             import utils
             for idx, fake_op in enumerate(fake_set):
                 raw = sess.run(fake_op)
                 # Affine map x -> 2x - 1 (presumably [0, 1] -> [-1, 1];
                 # confirm against the generator's output range).
                 grid = raw * 2. - 1.
                 save_dir = dir + "/sample_imgs"
                 utils.mkdir(save_dir + '/')
                 target = save_dir + "/" + 'sample-%d-%d.jpg' % (idx, step)
                 my_utils.saveSampleImgs(imgs=grid,
                                         full_path=target,
                                         row=8,
                                         column=8)
Esempio n. 3
0
def sample_once(it):
    """Sample a 10x10 image grid from each of the three generators and save it.

    All three generator outputs are evaluated in one session call that feeds
    a fresh normal noise batch to ``z`` and, for every sample, the matching
    entry of ``one_hot_labels`` to ``cat_1`` / ``cat_2`` / ``cat_3``.  The
    grids are written to ``<dir>/sample_imgs`` as ``g1-it<it>.jpg``,
    ``g2-it<it>.jpg`` and ``g3-it<it>.jpg``.

    Args:
        it: iteration number embedded in the output file names.
    """
    rows = 10
    columns = 10
    n_samples = rows * columns

    generator_ops = [images_form_g1, images_form_g2, images_form_g3]
    file_names = [prefix + '-it%d.jpg' % it for prefix in ('g1', 'g2', 'g3')]

    feed = {z: np.random.normal(size=[n_samples, z_dim])}
    # Each category placeholder gets the same one-hot row for every sample.
    for placeholder, label_idx in ((cat_1, 0), (cat_2, 1), (cat_3, 2)):
        feed[placeholder] = [one_hot_labels[label_idx]] * n_samples

    grids = sess.run(generator_ops, feed_dict=feed)

    save_dir = dir + "/sample_imgs"
    utils.mkdir(save_dir + '/')
    for grid, file_name in zip(grids, file_names):
        my_utils.saveSampleImgs(imgs=grid,
                                full_path=save_dir + "/" + file_name,
                                row=rows,
                                column=columns)
Esempio n. 4
0
def save_img(list_of_generators, list_of_names):
    """Evaluate the given generator outputs and save each as a 10x10 grid.

    All entries of ``list_of_generators`` are run in one session call with a
    fresh normal noise batch fed to ``z``.  Results are paired with
    ``list_of_names`` by position and written under ``<dir>/sample_imgs``.

    Args:
        list_of_generators: fetches passed to ``sess.run``.
        list_of_names: output file names, one per generator.
    """
    rows, columns = 10, 10
    noise = np.random.normal(size=[rows * columns, z_dim])
    grids = sess.run(list_of_generators, feed_dict={z: noise})

    save_dir = dir + "/sample_imgs"
    utils.mkdir(save_dir + '/')
    # zip stops at the shorter sequence, so surplus names or grids are ignored.
    for grid, file_name in zip(grids, list_of_names):
        target = save_dir + "/" + file_name
        my_utils.saveSampleImgs(imgs=grid,
                                full_path=target,
                                row=rows,
                                column=columns)
Esempio n. 5
0
def sample_once():
    """Evaluate ``fake`` once and save the result as ``sample.jpg``.

    The output is passed through x -> 2x - 1 and written to
    ``<dir>/sample_imgs/sample.jpg`` as a grid with 50 rows and 40 columns.
    """
    import utils

    generated = sess.run(fake)
    # Affine map x -> 2x - 1 (presumably [0, 1] -> [-1, 1]; confirm against
    # the generator's output range).
    rescaled = generated * 2. - 1.

    save_dir = dir + "/sample_imgs"
    utils.mkdir(save_dir + '/')
    target = save_dir + "/" + 'sample.jpg'
    my_utils.saveSampleImgs(imgs=rescaled, full_path=target, row=50, column=40)
    def gan_train(max_it, it_offset):
        """Run ``max_it`` GAN training iterations starting at ``it_offset``.

        Each iteration fetches a batch from ``data_pool`` and runs ``d_step``
        and ``g_step`` with the batch fed to ``real`` and ``real_weight_init``
        fed to ``real_weight``.  The merged summary is written every 10th
        iteration; every 1000th iteration an 8x8 image grid is saved for each
        entry of ``fake_set``.

        Args:
            max_it: number of iterations to run.
            it_offset: index of the first iteration.
        """
        print("GAN iteration: " + str(max_it))

        for step in range(it_offset, it_offset + max_it):
            # Labels are part of the batch but are not used by this loop.
            real_ipt, _labels = data_pool.batch(['img', 'label'])
            feed = {real: real_ipt, real_weight: real_weight_init}

            _, _ = sess.run([d_step, g_step], feed_dict=feed)

            if step % 10 == 0:
                writer.add_summary(sess.run(merged, feed_dict=feed), step)

            if step % 1000 != 0:
                continue

            for idx, fake_op in enumerate(fake_set):
                # Affine map x -> 2x - 1 (presumably [0, 1] -> [-1, 1];
                # confirm against the generator's output range).
                grid = sess.run(fake_op) * 2. - 1.
                save_dir = dir + "/sample_imgs"
                utils.mkdir(save_dir + '/')
                file_name = 'sample-%d-%d.jpg' % (idx, step)
                my_utils.saveSampleImgs(imgs=grid,
                                        full_path=save_dir + "/" + file_name,
                                        row=8,
                                        column=8)
Esempio n. 7
0
            writer.add_summary(summary, it)

    var = raw_input("Continue training for %d iterations?" % max_it)
    if var.lower() == 'y':
        training(max_it, it_offset + max_it)

# Script entry: run the initial training session, then -- whether or not
# training failed -- optionally dump sample images, save a checkpoint and
# close the session.
total_it = 0  # stays 0 if training raises before global_step is read
try:
    training(max_it, 0)
    total_it = sess.run(global_step)
    print("Total iterations: " + str(total_it))
except Exception:
    # The exception object itself is not needed: print_exc() reads the
    # active exception.  This spelling is also valid on every Python version,
    # unlike the old ``except Exception, e`` form.
    traceback.print_exc()
finally:
    # Only an answer of exactly 'y' / 'Y' triggers sampling.
    var = raw_input("Save sample images?")
    if var.lower() == 'y':
        list_of_generators = [images_form_g1, images_form_g2]  # used for sampling images
        list_of_names = ['g1-it%d.jpg' % total_it, 'g2-it%d.jpg' % total_it]
        rows = 10
        columns = 10
        sample_imgs = sess.run(list_of_generators,
                               feed_dict={z: np.random.normal(size=[rows * columns, z_dim])})
        save_dir = dir + "/sample_imgs"
        utils.mkdir(save_dir + '/')
        for imgs, name in zip(sample_imgs, list_of_names):
            my_utils.saveSampleImgs(imgs=imgs, full_path=save_dir + "/" + name,
                                    row=rows, column=columns)
    # save checkpoint
    save_path = saver.save(sess, dir + "/checkpoint/model.ckpt")
    print("Model saved in path: %s" % save_path)
    print(" [*] Close main session!")
    sess.close()
Esempio n. 8
0
    except Exception, e:
        traceback.print_exc()
    finally:
        import utils
        i = 0
        for f in fake_set:
            sample_imgs = sess.run(f)
            # if normalize:
            #     for i in range(len(sample_imgs)):
            sample_imgs = sample_imgs * 2. - 1.
            save_dir = dir + "/sample_imgs"
            utils.mkdir(save_dir + '/')
            # for imgs, name in zip(sample_imgs, list_of_names):
            my_utils.saveSampleImgs(imgs=sample_imgs,
                                    full_path=save_dir + "/" +
                                    'sample-%d.jpg' % i,
                                    row=8,
                                    column=8)
            i += 1

        # save checkpoint
        save_path = saver.save(sess, dir + "/checkpoint/model.ckpt")
        print("Model saved in path: %s" % save_path)
        print(" [*] Close main session!")
        sess.close()


def run(exp):
    if exp == 'imsat-svhn':
        print('Expriment IMSAT + GAN on SVHN, one experiments')
        print('=====Exp 1=====')