Example #1
import argparse
import os

import pandas as pd

# ImageSampler, NoiseSampler, Generator, Discriminator and WGAN are
# project-local modules assumed to be importable from this repository.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('image_dir', type=str)
    parser.add_argument('--batch_size', '-bs', type=int, default=64)
    parser.add_argument('--nb_epoch', '-e', type=int, default=1000)
    parser.add_argument('--noise_dim', '-nd', type=int, default=128)
    parser.add_argument('--height', '-ht', type=int, default=128)
    parser.add_argument('--width', '-wd', type=int, default=128)
    parser.add_argument('--save_steps', '-ss', type=int, default=10)
    parser.add_argument('--visualize_steps', '-vs', type=int, default=10)
    parser.add_argument('--lambda', '-l', type=float, default=10., dest='lmbd')
    parser.add_argument('--initial_steps', '-is', type=int, default=20)
    parser.add_argument('--initial_critics', '-sc', type=int, default=20)
    parser.add_argument('--normal_critics', '-nc', type=int, default=5)
    parser.add_argument('--model_dir', '-md', type=str, default="./params")
    parser.add_argument('--result_dir', '-rd', type=str, default="./result")
    parser.add_argument('--noise_mode', '-nm', type=str, default="uniform")
    parser.add_argument('--upsampling', '-up', type=str, default="subpixel")
    parser.add_argument('--dis_norm', '-dn', type=str, default=None)

    args = parser.parse_args()

    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    # output config to csv
    config_path = os.path.join(args.result_dir, "config.csv")
    dict_ = vars(args)
    df = pd.DataFrame(list(dict_.items()), columns=['attr', 'status'])
    df.to_csv(config_path, index=False)

    input_shape = (args.height, args.width, 3)

    image_sampler = ImageSampler(args.image_dir, target_size=input_shape[:2])
    noise_sampler = NoiseSampler(args.noise_mode)

    generator = Generator(args.noise_dim,
                          is_training=True,
                          upsampling=args.upsampling)
    discriminator = Discriminator(input_shape,
                                  is_training=True,
                                  normalization=args.dis_norm)

    wgan = WGAN(generator, discriminator, lambda_=args.lmbd, is_training=True)

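    # Schedule flags forwarded to fit: presumably the critic is updated
    # `initial_critics` times per generator step for the first
    # `initial_steps` iterations, and `normal_critics` times thereafter.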
    wgan.fit(image_sampler.flow(args.batch_size),
             noise_sampler,
             nb_epoch=args.nb_epoch,
             result_dir=args.result_dir,
             model_dir=args.model_dir,
             save_steps=args.save_steps,
             visualize_steps=args.visualize_steps,
             initial_steps=args.initial_steps,
             initial_critics=args.initial_critics,
             normal_critics=args.normal_critics)
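
A minimal command-line invocation might look like the following (train.py is a hypothetical filename for the script above; every other flag falls back to its default):

    python train.py ./data/images --batch_size 32 --nb_epoch 500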
Example #2
from tensorflow.keras import models

# WeightedSum, PixelNormalization, MinibatchStdev, Conv2DEQ, DenseEQ,
# Generator, Discriminator and WGAN are project-local modules assumed to be
# importable from this repository.


def load_model(g_dir, d_dir, latent_dim):
    # Parse the block counters out of a checkpoint filename such as
    # "<name>-<cur_block>-<n_blocks>.h5"; both are returned as strings.
    g_name = g_dir.split('/')[-1]

    n_blocks = g_name.split('-')[-1].split('.')[-2]
    cur_block = g_name.split('-')[-2]

    cus = {
        'WeightedSum': WeightedSum,
        'PixelNormalization': PixelNormalization,
        'MinibatchStdev': MinibatchStdev,
        'Conv2DEQ': Conv2DEQ,
        'DenseEQ': DenseEQ
    }

    g_model = Generator(latent_dim)
    d_model = Discriminator()

    g_model.model = models.load_model(g_dir, custom_objects=cus, compile=False)
    d_model.model = models.load_model(d_dir, custom_objects=cus, compile=False)

    wgan = WGAN(discriminator=d_model,
                generator=g_model,
                latent_dim=latent_dim,
                d_train=True,
                discriminator_extra_steps=1)

    return wgan, n_blocks, cur_block
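
A hedged usage sketch; the checkpoint paths and the "<name>-<cur_block>-<n_blocks>.h5" naming convention are assumptions inferred from the filename parsing above:

    wgan, n_blocks, cur_block = load_model('checkpoints/g_model-2-6.h5',
                                           'checkpoints/d_model-2-6.h5',
                                           latent_dim=512)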
Example #3
    def __init__(self, flags):
        # TensorFlow 1.x session setup: let GPU memory grow on demand
        # instead of pre-allocating it all.
        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.dataset = Dataset(self.sess, flags, self.flags.dataset)
        self.model = WGAN(self.sess, self.flags, self.dataset)

        self._make_folders()
        self.iter_time = 0

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())

        tf_utils.show_all_variables()
Example #4
import numpy as np

# WGAN is a project-local module assumed to be importable here.


def main():

    for i in range(1000):
        a = np.random.randint(10, 41)
        # draw b strictly above 10 so the range for c is never empty
        b = np.random.randint(max(a, 11), 70)
        c = np.random.randint(10, b)

        args = {
            'attack_type': "smurf",
            'max_epochs': 7000,
            'batch_size': 255,
            'sample_size': 500,
            'optimizer_learning_rate': 0.001,
            'generator_layers': [a, b, c]
        }
        for _ in range(10):
            gan = WGAN(**args)
            gan.train()
            print("GAN finished with layers:")
            print(str([a, b, c]))
Example #5
import json

from torch.utils.data import DataLoader

# PairDataset and WGAN are project-local modules; NUM_STEPS, BATCH_SIZE,
# IMAGE_SIZE, N_DISCRIMINATOR, SAVE_STEP, MODEL_SAVE_PREFIX and TRAIN_LOGS
# are assumed to be defined elsewhere in this script.


def main():

    dataset = PairDataset(
        first_dir='',
        second_dir='',
        num_samples=NUM_STEPS * BATCH_SIZE,
        image_size=IMAGE_SIZE
    )
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE, shuffle=True,
        num_workers=1, pin_memory=True
    )
    gan = WGAN(IMAGE_SIZE)

    logs = []
    text = 'i: {0}, content: {1:.3f}, tv: {2:.5f}, ' +\
           'realism: {3:.3f}, discriminator: {4:.3f}'

    for i, (x, y) in enumerate(data_loader, 1):

        x = x.cuda()
        y = y.cuda()

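        # WGAN schedule: the critic trains every step, the generator only
        # every N_DISCRIMINATOR-th step.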
        update_generator = i % N_DISCRIMINATOR == 0
        losses = gan.train_step(x, y, update_generator)

        log = text.format(
            i, losses['content'], losses['tv'],
            losses['realism_generation'], losses['discriminator']
        )
        print(log)
        logs.append(losses)

        if i % SAVE_STEP == 0:
            gan.save_model(MODEL_SAVE_PREFIX)
            with open(TRAIN_LOGS, 'w') as f:
                json.dump(logs, f)
Example #6
def gan_repository(sess, flags, dataset):
    if flags.gan_model == 'vanilla_gan':
        print('Initializing Vanilla GAN...')
        return GAN(sess, flags, dataset.image_size)
    elif flags.gan_model == 'dcgan':
        print('Initializing DCGAN...')
        return DCGAN(sess, flags, dataset.image_size)
    elif flags.gan_model == 'pix2pix':
        print('Initializing pix2pix...')
        return Pix2Pix(sess, flags, dataset.image_size)
    elif flags.gan_model == 'pix2pix-patch':
        print('Initializing pix2pix-patch...')
        return Pix2PixPatch(sess, flags, dataset.image_size)
    elif flags.gan_model == 'wgan':
        print('Initializing WGAN...')
        return WGAN(sess, flags, dataset)
    elif flags.gan_model == 'cyclegan':
        print('Initializing cyclegan...')
        return CycleGAN(sess, flags, dataset.image_size, dataset())
    elif flags.gan_model == 'mrigan':
        print('Initializing mrigan...')
        return MRIGAN(sess, flags, dataset.image_size, dataset())
    elif flags.gan_model == 'mrigan02':
        print('Initializing mrigan02...')
        return MRIGAN02(sess, flags, dataset.image_size, dataset())
    elif flags.gan_model == 'mrigan03':
        print('Initializing mrigan03...')
        return MRIGAN03(sess, flags, dataset.image_size, dataset())
    elif flags.gan_model == 'mrigan01_lsgan':
        print('Initializing mrigan01_lsgan...')
        return MRIGAN01_LSGAN(sess, flags, dataset.image_size, dataset())
    elif flags.gan_model == 'mrigan02_lsgan':
        print('Initializing mrigan02_lsgan...')
        return MRIGAN02_LSGAN(sess, flags, dataset.image_size, dataset())
    elif flags.gan_model == 'mrigan03_lsgan':
        print('Initializing mrigan03_lsgan...')
        return MRIGAN03_LSGAN(sess, flags, dataset.image_size, dataset())
    elif flags.gan_model == 'mrigan_01':
        print('Initializing mrigan_01...')
        return MRIGAN_01(sess, flags, dataset.image_size, dataset())
    elif flags.gan_model == 'mrigan_02':
        print('Initializing mrigan_02...')
        return MRIGAN_02(sess, flags, dataset.image_size, dataset())
    else:
        raise NotImplementedError
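
A hedged usage sketch; flags is assumed to be the repository's parsed FLAGS object carrying a gan_model attribute:

    model = gan_repository(sess, flags, dataset)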
Example #7
import sys

sys.path.append("..")
import utils
from wgan import WGAN

dataset, _, timesteps = utils.load_splitted_dataset()
# dataset, _, timesteps = utils.load_resized_mnist()

clip_value = 0.01

run_dir, img_dir, model_dir, generated_datesets_dir = utils.generate_run_dir()

config_2 = {
    'timesteps': timesteps,
    'run_dir': run_dir,
    'img_dir': img_dir,
    'model_dir': model_dir,
    'generated_datesets_dir': generated_datesets_dir,
    'clip_value': clip_value
}

config = utils.merge_config_and_save(config_2)

wgan = WGAN(config)
losses = wgan.train(dataset)
Example #8
import tensorflow as tf

import input_data  # project-local MNIST loader (assumed; not TF's bundled input_data)
from wgan import WGAN

if __name__ == '__main__':
    # get input data
    mnist_data = input_data.load_mnist_dataset('../../dataset/mnist_data',
                                               one_hot=True)
    num_sample = mnist_data.train.num_examples
    dataset = 'mnist'
    if dataset == 'mnist':
        input_dim = 784

    # define latent dimension
    z_dim = 100

    num_epoch = 100000
    batch_size = 100

    # Launch the session
    with tf.Session() as sess:
        gan = WGAN(sess,
                   num_epoch=num_epoch,
                   batch_size=batch_size,
                   dataset=dataset,
                   input_dim=input_dim,
                   z_dim=z_dim)

        # build generative adversarial network
        gan.build_net()

        # train the model
        gan.train(mnist_data.train, num_sample)
Example #9
    print("Number of nonfinite critic outputs: {}/200".format(
        np.sum(np.logical_not(np.isfinite(critic_outputs)))))

    print("Testing the WGAN's generator and critic...")
    print("Integral of real_hist is {}".format(
        (real_hist * (8.0 / 20.0)).sum()))
    print("Integral of fake_hist is {}".format(
        (fake_hist * (8.0 / 20.0)).sum()))

    print("Real distribution:")
    draw_density_histogram(real_hist, 8.0 / 20.0, max_height=0.3)
    print("Generated distribution:")
    draw_density_histogram(fake_hist, 8.0 / 20.0, max_height=0.3)
    print("Critic's histogram (low==fake, high==real):")
    draw_histogram(critic_hist, max_height=1.0, num_rows=5)


wgan = WGAN([1], generator_model, critic_model, gradient_penalty_factor=0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # TODO: Initialize only WGAN variables.

    with sess.as_default():

        for i in range(10000):
            wgan.train(50, sample_real_data(20))

            if i % 1000 == 0:
                test_wgan(wgan)
Example #10
import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

# Dataset, WGAN and tf_utils are project-local modules assumed to be
# importable from this repository.


class Solver(object):
    def __init__(self, flags):
        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.dataset = Dataset(self.sess, flags, self.flags.dataset)
        self.model = WGAN(self.sess, self.flags, self.dataset)

        self._make_folders()
        self.iter_time = 0

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())

        tf_utils.show_all_variables()

    def _make_folders(self):
        if self.flags.is_train:
            if self.flags.load_model is None:
                cur_time = datetime.now().strftime("%Y%m%d-%H%M")
                self.model_out_dir = "{}/model/{}".format(
                    self.flags.dataset, cur_time)
                if not os.path.isdir(self.model_out_dir):
                    os.makedirs(self.model_out_dir)
            else:
                cur_time = self.flags.load_model
                self.model_out_dir = "{}/model/{}".format(
                    self.flags.dataset, cur_time)

            self.sample_out_dir = "{}/sample/{}".format(
                self.flags.dataset, cur_time)
            if not os.path.isdir(self.sample_out_dir):
                os.makedirs(self.sample_out_dir)

            self.train_writer = tf.summary.FileWriter(
                "{}/logs/{}".format(self.flags.dataset, cur_time),
                graph=self.sess.graph)  # 'graph' supersedes the deprecated 'graph_def'
        else:
            self.model_out_dir = "{}/model/{}".format(self.flags.dataset,
                                                      self.flags.load_model)
            self.test_out_dir = "{}/test/{}".format(self.flags.dataset,
                                                    self.flags.load_model)
            if not os.path.isdir(self.test_out_dir):
                os.makedirs(self.test_out_dir)

    def train(self):
        # load the provided checkpoint, if any
        if self.flags.load_model is not None:
            if self.load_model():
                print(' [*] Load SUCCESS!\n')
            else:
                print(' [!] Load Failed...\n')

        while self.iter_time < self.flags.iters:
            # sample images and save them
            self.sample(self.iter_time)

            # train_step
            loss, summary = self.model.train_step()
            self.model.print_info(loss, self.iter_time)
            self.train_writer.add_summary(summary, self.iter_time)
            self.train_writer.flush()

            # save model
            self.save_model(self.iter_time)
            self.iter_time += 1

        self.save_model(self.flags.iters)

    def test(self):
        if self.load_model():
            print(' [*] Load SUCCESS!')
        else:
            print(' [!] Load Failed...')

        num_iters = 20
        total_time = 0.
        for iter_time in range(num_iters):
            print('iter_time: {}'.format(iter_time))

            # measure inference time
            start_time = time.time()
            imgs = self.model.sample_imgs()  # inference
            total_time += time.time() - start_time
            self.model.plots(imgs, iter_time, self.test_out_dir)

        print('Avg PT: {:.2f} msec.'.format(total_time / num_iters * 1000.))

    def sample(self, iter_time):
        if np.mod(iter_time, self.flags.sample_freq) == 0:
            imgs = self.model.sample_imgs()
            self.model.plots(imgs, iter_time, self.sample_out_dir)

    def save_model(self, iter_time):
        if np.mod(iter_time + 1, self.flags.save_freq) == 0:
            model_name = 'model'
            self.saver.save(self.sess,
                            os.path.join(self.model_out_dir, model_name),
                            global_step=iter_time)

            print('=====================================')
            print('             Model saved!            ')
            print('=====================================\n')

    def load_model(self):
        print(' [*] Reading checkpoint...')

        ckpt = tf.train.get_checkpoint_state(self.model_out_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess,
                               os.path.join(self.model_out_dir, ckpt_name))

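            # Recover the global step from the checkpoint filename
            # (saver.save(..., global_step=i) produces e.g. "model-12345").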
            meta_graph_path = ckpt.model_checkpoint_path + '.meta'
            self.iter_time = int(meta_graph_path.split('-')[-1].split('.')[0])

            print('===========================')
            print('   iter_time: {}'.format(self.iter_time))
            print('===========================')
            return True
        else:
            return False
Example #11
import sys

import matplotlib.pyplot as plt
import numpy as np

from wgan import WGAN  # project-local module


def plot(samples, r, c):
    """
    Plot samples in an r * c grid; samples has shape (r * c,) + wgan.img_shape.
    :return:
    """
    assert samples.shape[0] == r * c

    # map from [-1, 1] to [0, 255]
    samples = np.round((samples + 1) / 2 * 255).astype('uint8')

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(samples[cnt, :, :, :])
            axs[i, j].axis('off')
            cnt += 1
    # fig.savefig(os.path.join(save_dir, 'img.png'))
    plt.show()


if __name__ == '__main__':
    if len(sys.argv) < 2:
        print('Format: python plot.py <epoch index>')
        sys.exit(0)
    epoch = int(sys.argv[1])
    wgan = WGAN()
    wgan.load_model(epoch)
    r = 3
    c = 3
    samples = wgan.choose_best_generated_images(100, r * c)
    plot(samples, r, c)
Example #12
import sys
import scipy.misc
import numpy as np
from vanilla_gan import Vanilla_GAN
from dcgan import DCGAN
from cgan import CGAN
from infogan import InfoGAN
from wgan import WGAN
import tensorflow as tf

if __name__ == '__main__':
    model_name = sys.argv[1]
    dataset = sys.argv[2]
    with tf.Session() as sess:
        if model_name == 'cgan':
            model = CGAN(sess, dataset)
        elif model_name == 'vanilla_gan':
            model = Vanilla_GAN(sess, dataset)
        elif model_name == 'dcgan':
            model = DCGAN(sess, dataset)
        elif model_name == 'infogan':
            model = InfoGAN(sess, dataset)
        elif model_name == 'wgan':
            model = WGAN(sess, dataset)
        else:
            print("We cannot find this model")
            sys.exit(1)

        model.train()

        print("finish to train dcgan")
Example #13
    # Fragment of a training entry point: args and mode are parsed in the
    # omitted part of the script.
    n_blocks = int(args.n_blocks)

    # 4x, 8x, 16x, 32x, 64x, 128x, 256x, 512x, 1024x
    n_batch = [128, 128, 64, 32, 6, 4, 2, 1, 1]
    n_epochs = [8, 10, 10, 10, 10, 15, 15, 15, 20]

    if mode == 'train':
        # size of the latent space
        latent_dim = 512

        # define base model
        d_base = Discriminator()
        g_base = Generator(latent_dim)
        wgan = WGAN(discriminator=d_base,
                    generator=g_base,
                    latent_dim=latent_dim,
                    d_train=True,
                    discriminator_extra_steps=1)

        # prepare image generator
        real_gen = ImageDataGenerator(rescale=None, preprocessing_function=pre)

        # train model (n_epochs is passed twice, presumably serving as both
        # the fade-in and the stabilisation schedule of progressive growing)
        train(wgan, latent_dim, n_epochs, n_epochs, n_batch, n_blocks,
              real_gen, DATA_DIR, SAVE_DIR, args.dynamic_resize)

    elif mode == 'resume':
        g_model_dir = args.g_model_path
        d_model_dir = args.d_model_path

        # size of the latent space
Example #14
from pytorch_lightning import Trainer

from discriminator import Discriminator
from faces_data_module import FacesDataModule
from generator import Generator
from wgan import WGAN

if __name__ == "__main__":
    data_module = FacesDataModule()
    wgan = WGAN(generator=Generator(), discriminator=Discriminator())

    # automatic_optimization was a Trainer argument in the PyTorch Lightning
    # version this snippet targets; recent releases set it on the
    # LightningModule instead.
    trainer = Trainer(automatic_optimization=False)
    trainer.fit(wgan, data_module)
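
With automatic optimization off, the WGAN module must step its optimizers by hand. A minimal sketch of what such a training_step can look like, assuming the usual two-optimizer GAN pattern (critic_loss and generator_loss are illustrative names, not taken from the source):

    # Hypothetical sketch -- not the repository's actual WGAN implementation.
    def training_step(self, batch, batch_idx):
        g_opt, d_opt = self.optimizers()   # fetch both optimizers (manual mode)

        d_loss = self.critic_loss(batch)   # critic update
        d_opt.zero_grad()
        self.manual_backward(d_loss)
        d_opt.step()

        g_loss = self.generator_loss(batch)  # generator update
        g_opt.zero_grad()
        self.manual_backward(g_loss)
        g_opt.step()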