Esempio n. 1
0
def run_task(variant):
    """Train a ``SimpleGAN`` on noisy circle samples and write an HTML report.

    The training set is 5000 2-D points placed at uniformly random angles on a
    circle of radius 0.5, each perturbed by Gaussian noise with scale 0.05.
    The report (``report.html`` inside the logger's snapshot directory) gets
    the experiment settings, a plot of the first 500 real points, a plot of
    generator samples before training, and, after each of 30 training
    rounds, the losses plus fresh generator and discriminator plots.

    Args:
        variant: Mapping of experiment settings. ``variant['outer_iters']``
            is forwarded to ``gan.train``; the whole mapping is rendered into
            the report through ``format_dict``.
    """
    log_dir = logger.get_snapshot_dir()
    report = HTMLReport(os.path.join(log_dir, 'report.html'),
                        images_per_row=2,
                        default_image_width=500)
    report.add_header('Simple Circle Sampling')
    report.add_text(format_dict(variant))
    report.save()

    # Keep a handle on the session so it can be closed explicitly; creating
    # it inline in the constructor call left it open after the task finished.
    session = tf.Session()
    try:
        gan = SimpleGAN(noise_size=5, tf_session=session)

        rand_theta = np.random.uniform(0, 2 * np.pi, size=(5000, 1))
        data = np.hstack([0.5 * np.cos(rand_theta), 0.5 * np.sin(rand_theta)])
        data = data + np.random.normal(scale=0.05, size=data.shape)

        report.add_image(plot_samples(data[:500, :]), 'Real data')

        # Snapshot of the untrained generator; the second return value of
        # sample_generator is not needed here.
        generated_samples, _ = gan.sample_generator(100)
        report.add_image(plot_samples(generated_samples))

        for outer_iter in range(30):
            dloss, gloss = gan.train(
                data,
                outer_iters=variant['outer_iters'],
            )
            # Same message goes to both the logger and the report.
            message = (
                'Outer iteration: {}, disc loss: {}, gen loss: {}'.format(
                    outer_iter, dloss, gloss))
            logger.log(message)
            report.add_text(message)

            generated_samples, _ = gan.sample_generator(50)
            report.add_image(plot_samples(generated_samples))
            report.add_image(plot_dicriminator(gan))

            # Save every round so partial results survive an interrupted run.
            report.save()
    finally:
        session.close()
Esempio n. 2
0
def run_task(variant):
    """Train an ``FCGAN`` (``gan_type`` ``'lsgan'``) on noisy circle samples.

    The training set is 5000 2-D points placed at uniformly random angles on a
    circle of radius 0.5, each perturbed by Gaussian noise with scale 0.05.
    The report (``report.html`` inside the logger's snapshot directory) gets
    the experiment settings, a plot of the first 500 real points, and, after
    each of 30 training rounds, the losses plus fresh generator and
    discriminator plots.

    Args:
        variant: Mapping of experiment settings. Keys read here:
            ``generator_learning_rate`` and ``discriminator_learning_rate``
            (passed to the two RMSProp optimizers), ``generator_init``
            (either the string ``'xavier'`` or a value used as the stddev of
            a truncated-normal initializer), and ``outer_iters``,
            ``generator_iters``, ``discriminator_iters`` (forwarded to
            ``gan.train``). The whole mapping is rendered into the report
            through ``format_dict``.
    """
    gan_configs = {
        'batch_size': 64,
        'generator_output_activation': 'tanh',
        'generator_optimizer': tf.train.RMSPropOptimizer(
            variant['generator_learning_rate']),
        'discriminator_optimizer': tf.train.RMSPropOptimizer(
            variant['discriminator_learning_rate']),
        'batch_normalize_discriminator': False,
        'batch_normalize_generator': False,
        'gan_type': 'lsgan',
    }

    if variant['generator_init'] == 'xavier':
        gan_configs['generator_weight_initializer'] = (
            tf.contrib.layers.xavier_initializer())
    else:
        # Any other value is used directly as the truncated-normal stddev.
        gan_configs['generator_weight_initializer'] = (
            tflearn.initializations.truncated_normal(
                stddev=variant['generator_init']))

    # Keep a handle on the session so it can be closed explicitly; creating
    # it inline in the constructor call left it open after the task finished.
    session = tf.Session()
    try:
        gan = FCGAN(
            generator_output_size=2,
            discriminator_output_size=1,
            generator_layers=[200, 200],
            discriminator_layers=[128, 128],
            noise_size=5,
            tf_session=session,
            configs=gan_configs,
        )

        log_dir = logger.get_snapshot_dir()
        report = HTMLReport(
            os.path.join(log_dir, 'report.html'), images_per_row=2,
            default_image_width=500
        )
        report.add_header('Simple Circle Sampling')
        report.add_text(format_dict(variant))
        report.save()

        rand_theta = np.random.uniform(0, 2 * np.pi, size=(5000, 1))
        data = np.hstack([0.5 * np.cos(rand_theta), 0.5 * np.sin(rand_theta)])
        data = data + np.random.normal(scale=0.05, size=data.shape)

        report.add_image(plot_samples(data[:500, :]), 'Real data')

        for outer_iter in range(30):
            # Second positional argument is a column of ones, one per data
            # row -- presumably the "real" labels; confirm against
            # FCGAN.train.
            dloss, gloss = gan.train(
                data, np.ones((data.shape[0], 1)),
                outer_iters=variant['outer_iters'],
                generator_iters=variant['generator_iters'],
                discriminator_iters=variant['discriminator_iters']
            )
            # Same message goes to both the logger and the report.
            message = (
                'Outer iteration: {}, disc loss: {}, gen loss: {}'.format(
                    outer_iter, dloss, gloss))
            logger.log(message)
            report.add_text(message)

            # The second return value of sample_generator is not needed here.
            generated_samples, _ = gan.sample_generator(50)
            report.add_image(plot_samples(generated_samples))
            report.add_image(plot_dicriminator(gan))

            # Save every round so partial results survive an interrupted run.
            report.save()
    finally:
        session.close()