Example #1
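The two methods below are an excerpt from a GAN benchmark and depend on module-level setup that the example does not show. The following sketch reconstructs the minimum needed to read the code; the `mnist` import path, the constant values, the helper bodies, and the class name are assumptions, not taken from the original source.

import tempfile
import time

import numpy as np
import tensorflow as tf

# Assumed: the MNIST GAN models/losses module from the tf.contrib.eager examples.
from tensorflow.contrib.eager.python.examples.gan import mnist

NOISE_DIM = 100         # assumed size of the generator's noise vector
SUMMARY_INTERVAL = 10   # assumed: log summaries every 10 steps
SUMMARY_FLUSH_MS = 100  # assumed: flush the summary writer every 100 ms


def data_format():
    # channels_first is typically faster on GPU, channels_last on CPU.
    return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'


def device():
    return '/gpu:0' if tf.test.is_gpu_available() else '/cpu:0'


# Assumed class declaration for the methods below:
class MnistGraphGanBenchmark(tf.test.Benchmark):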
    def _create_graph(self, batch_size):
        # Generate some random data.
        images_data = np.random.randn(batch_size, 784).astype(np.float32)
        dataset = tf.data.Dataset.from_tensors(images_data)
        images = dataset.repeat().make_one_shot_iterator().get_next()

        # Create the models and optimizers
        generator = mnist.Generator(data_format())
        discriminator = mnist.Discriminator(data_format())
        with tf.variable_scope('generator'):
            generator_optimizer = tf.train.AdamOptimizer(0.001)
        with tf.variable_scope('discriminator'):
            discriminator_optimizer = tf.train.AdamOptimizer(0.001)

        # Run models and compute loss
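        # Noise is supplied at session-run time through this placeholder.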
        noise_placeholder = tf.placeholder(tf.float32,
                                           shape=[batch_size, NOISE_DIM])
        generated_images = generator(noise_placeholder)
        tf.contrib.summary.image('generated_images',
                                 tf.reshape(generated_images, [-1, 28, 28, 1]),
                                 max_images=10)
        discriminator_gen_outputs = discriminator(generated_images)
        discriminator_real_outputs = discriminator(images)
        generator_loss = mnist.generator_loss(discriminator_gen_outputs)
        discriminator_loss = mnist.discriminator_loss(
            discriminator_real_outputs, discriminator_gen_outputs)
        # Get train ops
        with tf.variable_scope('generator'):
            generator_train = generator_optimizer.minimize(
                generator_loss, var_list=generator.variables)
        with tf.variable_scope('discriminator'):
            discriminator_train = discriminator_optimizer.minimize(
                discriminator_loss, var_list=discriminator.variables)

        return (generator_train, discriminator_train, noise_placeholder)
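
    # For context: a minimal sketch of how the graph returned by
    # _create_graph would be driven, feeding fresh noise through the
    # placeholder on every step. This helper is hypothetical, not part of
    # the original example; the step count is illustrative.
    def _run_graph_sketch(self, batch_size=64, steps=100):
        generator_train, discriminator_train, noise_placeholder = (
            self._create_graph(batch_size))
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(steps):
                noise = np.random.uniform(
                    -1.0, 1.0, size=[batch_size, NOISE_DIM]).astype(np.float32)
                sess.run([generator_train, discriminator_train],
                         feed_dict={noise_placeholder: noise})
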
    def benchmark_train(self):
        for batch_size in [64, 128, 256]:
            # Generate random batches for the warm-up and the timed run.
            burn_batches, measure_batches = (3, 100)
            burn_images = [
                tf.random_normal([batch_size, 784])
                for _ in range(burn_batches)
            ]
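            # from_tensor_slices stacks the list of [batch_size, 784]
            # tensors, so each dataset element is one full batch.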
            burn_dataset = tf.data.Dataset.from_tensor_slices(burn_images)
            measure_images = [
                tf.random_normal([batch_size, 784])
                for _ in range(measure_batches)
            ]
            measure_dataset = tf.data.Dataset.from_tensor_slices(
                measure_images)

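            # One global step is shared by the warm-up and measured runs;
            # mnist.train_one_epoch uses it with log_interval for logging.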
            step_counter = tf.train.get_or_create_global_step()
            with tf.device(device()):
                # Create the models and optimizers
                generator = mnist.Generator(data_format())
                discriminator = mnist.Discriminator(data_format())
                with tf.variable_scope('generator'):
                    generator_optimizer = tf.train.AdamOptimizer(0.001)
                with tf.variable_scope('discriminator'):
                    discriminator_optimizer = tf.train.AdamOptimizer(0.001)

                with tf.contrib.summary.create_file_writer(
                        tempfile.mkdtemp(),
                        flush_millis=SUMMARY_FLUSH_MS).as_default():

                    # Warm up: run a few untimed batches so one-time setup
                    # costs do not skew the measurement.
                    mnist.train_one_epoch(generator,
                                          discriminator,
                                          generator_optimizer,
                                          discriminator_optimizer,
                                          burn_dataset,
                                          step_counter,
                                          log_interval=SUMMARY_INTERVAL,
                                          noise_dim=NOISE_DIM)
                    # Measure: time one epoch over the measurement batches.
                    start = time.time()
                    mnist.train_one_epoch(generator,
                                          discriminator,
                                          generator_optimizer,
                                          discriminator_optimizer,
                                          measure_dataset,
                                          step_counter,
                                          log_interval=SUMMARY_INTERVAL,
                                          noise_dim=NOISE_DIM)
                    self._report('train', start, measure_batches, batch_size)
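
    # `_report` is not shown in the original excerpt. A plausible sketch,
    # assuming the class extends tf.test.Benchmark; the benchmark name
    # format and the extras dict are assumptions.
    def _report(self, label, start, num_batches, batch_size):
        # Report average wall time per batch and throughput in examples/sec.
        avg_time = (time.time() - start) / num_batches
        name = 'graph_gan_%s_batch_size_%d' % (label, batch_size)
        extras = {'examples_per_sec': batch_size / avg_time}
        self.report_benchmark(iters=num_batches, wall_time=avg_time,
                              name=name, extras=extras)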