Example #1
    def __init__(self,
                 util: Utils,
                 hr_size: int = 96,
                 log_dir: str = None,
                 num_resblock: int = 16):
        self.vgg = self.vgg(20)
        self.learning_rate = 0.00005
        self.clipping = 0.01
        self.generator_optimizer = RMSprop(learning_rate=self.learning_rate,
                                           clipvalue=self.clipping)
        self.discriminator_optimizer = RMSprop(
            learning_rate=self.learning_rate, clipvalue=self.clipping)
        self.binary_cross_entropy = BinaryCrossentropy(from_logits=True)
        self.mean_squared_error = MeanSquaredError()
        self.util: Utils = util
        self.HR_SIZE = hr_size
        self.LR_SIZE = self.HR_SIZE // 4

        if log_dir is not None:
            self.summary_writer = tf.summary.create_file_writer(log_dir)
            if log_dir.startswith('../'):
                log_dir = log_dir[len('../'):]
            print('open tensorboard with: tensorboard --logdir ' + log_dir)

        else:
            self.summary_writer = None

        self.generator = make_generator_model(num_res_blocks=num_resblock)
        self.discriminator = make_discriminator_model(self.HR_SIZE)
        self.checkpoint = tf.train.Checkpoint(generator=self.generator,
                                              discriminator=self.discriminator)
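        # Added sketch, not part of the original example: the checkpoint built
        # above is usually paired with a tf.train.CheckpointManager so training
        # can be resumed; the directory name below is an assumption.
        #
        #     self.checkpoint_manager = tf.train.CheckpointManager(
        #         self.checkpoint, directory='./ckpt', max_to_keep=3)
        #     if self.checkpoint_manager.latest_checkpoint:
        #         self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)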
Example #2
    def test(self):
        test_img = tf.expand_dims(decode_img(self.test_img_path), axis=0)
        label_trg = tf.expand_dims(tf.constant(self.attr_values, dtype=tf.float32), axis=0)

        self.generator = make_generator_model(label_dim=self.label_dim, g_conv_dim=self.g_conv_dim)
        generated_img = self.generator([test_img, label_trg])

        init = tf.global_variables_initializer()
        saver = tf.train.Saver(max_to_keep=1)

        with tf.Session() as sess:
            sess.run(init)

            # checkpoint
            latest_ckpt = tf.train.latest_checkpoint(self.ckpt_dir)
            if latest_ckpt is None:
                print('No checkpoint found!')
            else:
                print('Found checkpoint : "{}"'.format(latest_ckpt))
                saver.restore(sess, latest_ckpt)

            # generate result
            test_img_arr, generated_img_arr = sess.run([test_img, generated_img])  # (1, 128, 128, 3)
            test_img_arr = np.squeeze(test_img_arr, axis=0)  # (128, 128, 3)
            generated_img_arr = np.squeeze(generated_img_arr, axis=0)  # (128, 128, 3)

            display_test_result(test_img_arr, generated_img_arr,
                                self.selected_attributes, self.attr_values, self.test_img_path)
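
        # Added note, not part of the original example: decode_img is not defined
        # in this snippet. A plausible definition, matching the (1, 128, 128, 3)
        # shapes noted above and the decode_img shown in Example #6 below, would be:
        #
        #     def decode_img(img_path):
        #         img_raw = tf.io.read_file(img_path)
        #         img = tf.image.decode_jpeg(img_raw, channels=3)
        #         img = tf.image.resize(img, [128, 128])
        #         return (tf.cast(img, tf.float32) - 127.5) / 127.5  # scale to [-1, 1]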
Example #3
    def build_model(self):
        self.generator = make_generator_model(label_dim=self.label_dim, g_conv_dim=self.g_conv_dim)
        self.discriminator = make_discriminator_model(label_dim=self.label_dim, d_conv_dim=self.d_conv_dim)

        self.g_lr_ph = tf.placeholder(tf.float32, name='g_learning_rate')
        self.d_lr_ph = tf.placeholder(tf.float32, name='d_learning_rate')
        self.g_optimizer = tf.train.AdamOptimizer(self.g_lr_ph, beta1=0.5, beta2=0.999)
        self.d_optimizer = tf.train.AdamOptimizer(self.d_lr_ph, beta1=0.5, beta2=0.999)
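
        # Added sketch, not part of the original example: in TF1 the learning-rate
        # placeholders above are fed at session run time, e.g. with a decayed value
        # per step. Hypothetical usage (g_loss, d_loss, g_lr and d_lr are assumed
        # to be defined elsewhere):
        #
        #     g_train_op = self.g_optimizer.minimize(
        #         g_loss, var_list=self.generator.trainable_variables)
        #     d_train_op = self.d_optimizer.minimize(
        #         d_loss, var_list=self.discriminator.trainable_variables)
        #     sess.run([d_train_op, g_train_op],
        #              feed_dict={self.g_lr_ph: g_lr, self.d_lr_ph: d_lr})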
Example #4
def train_gan(data,
              checkpoint_dir,
              start_epoch=100,
              epochs=200,
              restart=False,
              batch_size=64,
              lr_gen=2e-4,
              lr_disc=2e-4,
              num_examples_to_generate=16):

    # Create the generator and discriminator
    generator = make_generator_model()
    discriminator = make_discriminator_model()

    generator_loss = loss_functions.generator_loss
    discriminator_loss = loss_functions.discriminator_loss

    generator_optimizer = tf.keras.optimizers.Adam(lr_gen, beta_1=0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(lr_disc, beta_1=0.5)

    seed = tf.random.normal([num_examples_to_generate, noise_dim], seed=42)

    image_folder = os.path.join(checkpoint_dir, "Images")
    if not os.path.isdir(image_folder):
        try:
            os.mkdir(image_folder)
        except OSError:
            print(f"Unable to create {image_folder}. Exiting train_gan.")
            return

    # checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
    checkpoint = tf.train.Checkpoint(
        step=tf.Variable(1),
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator,
    )
    manager = tf.train.CheckpointManager(checkpoint,
                                         checkpoint_dir,
                                         max_to_keep=5)
    if restart:
        print("Initializing from scratch.")
    else:
        checkpoint.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

    train(data,
          epochs,
          start_epoch,
          batch_size,
          generator,
          discriminator,
          generator_loss,
          discriminator_loss,
          generator_optimizer,
          discriminator_optimizer,
          checkpoint,
          manager,
          image_folder,
          seed,
          use_smoothing=True,
          use_noise=True)
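

# Added sketch, not part of the original example: the loss_functions module used
# above is not shown here. A typical pair of GAN losses it could expose, assuming
# the usual from-logits binary cross-entropy formulation (the optional label
# smoothing mirrors the use_smoothing flag; names and signatures are assumptions):

_bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def sketch_discriminator_loss(real_output, fake_output, smoothing=0.9):
    # One-sided label smoothing: real labels become 0.9 instead of 1.0.
    real_loss = _bce(tf.ones_like(real_output) * smoothing, real_output)
    fake_loss = _bce(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss


def sketch_generator_loss(fake_output):
    # The generator is rewarded when the discriminator labels its fakes as real.
    return _bce(tf.ones_like(fake_output), fake_output)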
Example #5
    image = cv.imread(filename)
    all_images.append(image)

X = np.array(all_images)
train_images = X.reshape(X.shape[0], 64, 64, 3).astype('float32')
train_images = (train_images -
                127.5) / 127.5  # Normalize the images to [-1, 1]

BUFFER_SIZE = 500
BATCH_SIZE = 128
inputshape = 10
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

generator = make_generator_model(inputshape)
discriminator = make_discriminator_model()

generator.summary()
discriminator.summary()

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
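

# Added sketch, not part of the original snippet: the matching generator loss is
# not included here. The standard companion, reusing the cross_entropy helper
# defined above, would be:
def generator_loss(fake_output):
    # The generator is rewarded when the discriminator classifies its output as real.
    return cross_entropy(tf.ones_like(fake_output), fake_output)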
Example #6
def _main(args):
    ##############################
    # Distribute strategy
    ##############################
    gpus = tf.config.experimental.list_physical_devices('GPU')
    assert len(gpus) > 0, "No GPUs available."
    tf.config.experimental.set_visible_devices(
        gpus[-args.n_gpus:], 'GPU')  # use only the last n gpus
    strategy = tf.distribute.MirroredStrategy(
    )  # create a mirrored strategy for data parallel

    ##############################
    # Constants
    ##############################
    DATA_DIR = os.path.join('data', args.dataset_name)
    IMG_PATH_PATTERN = os.path.join(DATA_DIR, '*.*')
    CKPT_DIR = os.path.join('model', args.dataset_name + '_parallel')

    N_SAMPLES = len(glob(IMG_PATH_PATTERN))
    GLOBAL_BATCH_SIZE = args.batch_size_per_replica * strategy.num_replicas_in_sync

    SAMPLE_DIR = os.path.join('samples', args.dataset_name + '_parallel')
    #     SEED = tf.random.normal(shape=(NUM_EXAMPLES_TO_GENERATE, 1, 1, NOISE_DIM))
    SEED = tf.random.truncated_normal(shape=(NUM_EXAMPLES_TO_GENERATE, 1, 1,
                                             NOISE_DIM),
                                      stddev=0.5)

    ##############################
    # Create directories
    ##############################
    assert os.path.exists(DATA_DIR)
    if not os.path.exists(CKPT_DIR):
        os.makedirs(CKPT_DIR)
    if not os.path.exists(SAMPLE_DIR):
        os.makedirs(SAMPLE_DIR)

    ##############################
    # Prepare dataset
    ##############################
    def decode_img(img_path):
        img_raw = tf.io.read_file(img_path)
        img = tf.image.decode_jpeg(img_raw, channels=3)
        img = tf.image.resize(img, [128, 128])  # resize image
        img = (tf.cast(img, tf.float32) -
               127.5) / 127.5  # Normalize the images to [-1, 1]
        return img

    dataset = tf.data.Dataset.list_files(IMG_PATH_PATTERN)
    dataset = dataset.map(decode_img,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(BUFFER_SIZE).batch(
        GLOBAL_BATCH_SIZE,
        drop_remainder=True).repeat().prefetch(tf.data.experimental.AUTOTUNE)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    with strategy.scope():
        iterator = dist_dataset.make_initializable_iterator()
        iterator_init = iterator.initialize()
        image_ph = iterator.get_next()

    ##############################
    # Create models and optimizers
    ##############################
    with strategy.scope():
        generator = make_generator_model(NOISE_DIM)
        discriminator = make_discriminator_model()

        generator_optimizer = tf.keras.optimizers.Adam(
            args.g_learning_rate,
            beta_1=args.beta_1,
            beta_2=args.beta_2,
            epsilon=args.adam_epsilon)
        discriminator_optimizer = tf.keras.optimizers.Adam(
            args.d_learning_rate,
            beta_1=args.beta_1,
            beta_2=args.beta_2,
            epsilon=args.adam_epsilon)

    ##############################
    # Train step
    ##############################
    with strategy.scope():
        hinge_loss = tf.keras.losses.Hinge(reduction='none')

        def discriminator_real_loss(real_output):
            per_sample_loss = hinge_loss(
                tf.ones_like(real_output),
                real_output)  # (batch_size_per_replica,)
            loss = tf.reduce_sum(per_sample_loss) / tf.cast(
                GLOBAL_BATCH_SIZE, per_sample_loss.dtype)  # scalar
            return loss

        def discriminator_fake_loss(fake_output):
            per_sample_loss = hinge_loss(
                -tf.ones_like(fake_output),
                fake_output)  # (batch_size_per_replica,)
            loss = tf.reduce_sum(per_sample_loss) / tf.cast(
                GLOBAL_BATCH_SIZE, per_sample_loss.dtype)  # scalar
            return loss

        def generator_loss(fake_output):
            return -tf.reduce_sum(fake_output) / tf.cast(
                GLOBAL_BATCH_SIZE, fake_output.dtype)  # scalar

        def train_step(inputs):
            '''
            inputs: "per-replica" values, such as those produced by a "distributed Dataset"
            '''
            with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
                #                 noise = tf.random.normal(shape=(args.batch_size_per_replica, 1, 1, NOISE_DIM))
                noise = tf.random.truncated_normal(
                    shape=(args.batch_size_per_replica, 1, 1, NOISE_DIM),
                    stddev=0.5)
                real_output = discriminator(inputs,
                                            training=TrainArg.TRUE_UPDATE_U)
                fake_output = discriminator(generator(
                    noise, training=TrainArg.TRUE_UPDATE_U),
                                            training=TrainArg.TRUE_NO_UPDATE_U)

                disc_loss_op = discriminator_real_loss(
                    real_output) + discriminator_fake_loss(fake_output)
                gen_loss_op = generator_loss(fake_output)

            disc_grads = disc_tape.gradient(disc_loss_op,
                                            discriminator.trainable_variables)
            disc_train_op = discriminator_optimizer.apply_gradients(
                zip(disc_grads, discriminator.trainable_variables))
            update_ops = generator.get_updates_for(
                generator.inputs
            )  # to update moving_mean and moving_variance of BatchNormalization layers
            disc_train_op = tf.group([disc_train_op, update_ops])

            # make sure `loss`es will only be returned after `train_op`s have executed
            with tf.control_dependencies([disc_train_op]):
                disc_loss_op_id = tf.identity(disc_loss_op)

            gen_grads = gen_tape.gradient(gen_loss_op,
                                          generator.trainable_variables)
            gen_train_op = generator_optimizer.apply_gradients(
                zip(gen_grads, generator.trainable_variables))
            gen_train_op = tf.group([gen_train_op, update_ops])

            # make sure `loss`es will only be returned after `train_op`s have executed
            with tf.control_dependencies([gen_train_op]):
                gen_loss_op_id = tf.identity(gen_loss_op)

            return disc_loss_op_id, gen_loss_op_id

    ##############################
    # Training loop
    ##############################
    with strategy.scope():

        def distributed_train_step(dataset_inputs):
            per_replica_disc_losses, per_replica_gen_losses = strategy.experimental_run_v2(
                train_step, args=(dataset_inputs, ))
            mean_disc_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                             per_replica_disc_losses)
            mean_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                            per_replica_gen_losses)
            return mean_disc_loss, mean_gen_loss

        def generate_and_save_images(sess, index):
            test_fake_image = generator(SEED, training=TrainArg.FALSE)
            predictions = sess.run(test_fake_image)
            fig = plt.figure(figsize=(8, 8))

            for i in range(NUM_EXAMPLES_TO_GENERATE):
                plt.subplot(5, 5, i + 1)
                plt.imshow((predictions[i] * 127.5 + 127.5).astype(int))
                plt.axis('off')

            plt.savefig(
                os.path.join(SAMPLE_DIR, 'iter-{}.jpg'.format(str(index))))

        mean_disc_loss_op, mean_gen_loss_op = distributed_train_step(image_ph)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver(max_to_keep=2)

        with tf.Session() as sess:
            sess.run(init)
            sess.run(iterator_init)

            latest_ckpt = tf.train.latest_checkpoint(CKPT_DIR)
            if latest_ckpt is None:
                print('No checkpoint found!')
                START_ITER = 0
            else:
                print('Found checkpoint : "{}"'.format(latest_ckpt))
                saver.restore(sess, latest_ckpt)
                START_ITER = int(latest_ckpt.split('-')[-1])

            for iteration in range(START_ITER, args.n_iters):
                start = time.time()

                disc_loss = sess.run(mean_disc_loss_op)
                gen_loss = sess.run(mean_gen_loss_op)

                if (iteration + 1) % args.display_frequency == 0:
                    generate_and_save_images(sess, iteration + 1)

                print('discriminator loss: {}'.format(disc_loss))
                print('generator loss: {}'.format(gen_loss))
                print('Time for iteration {}/{} is {} sec'.format(
                    iteration + 1, args.n_iters,
                    time.time() - start))
                print('#############################################')

                if (iteration + 1) % args.save_frequency == 0:
                    saver.save(sess,
                               os.path.join(CKPT_DIR, 'model'),
                               global_step=iteration + 1)
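

# Added sketch, not part of the original example: a hypothetical command-line
# driver for _main. Only the attribute names are taken from the code above; the
# default values are purely illustrative.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--n_gpus', type=int, default=1)
    parser.add_argument('--batch_size_per_replica', type=int, default=32)
    parser.add_argument('--g_learning_rate', type=float, default=1e-4)
    parser.add_argument('--d_learning_rate', type=float, default=4e-4)
    parser.add_argument('--beta_1', type=float, default=0.0)
    parser.add_argument('--beta_2', type=float, default=0.9)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    parser.add_argument('--n_iters', type=int, default=100000)
    parser.add_argument('--display_frequency', type=int, default=1000)
    parser.add_argument('--save_frequency', type=int, default=5000)
    _main(parser.parse_args())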
Example #7
# Construct a tf.data.Dataset
(train_images, train_labels), (test_images,
                               test_labels) = datasets.mnist.load_data()

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0


# Reshape training and testing image
train_images = train_images.reshape(-1, 28, 28, 1)
test_images = test_images.reshape(-1, 28, 28, 1)


# get the generator and classifier
classifier = make_classifier_model()
generator = make_generator_model()

class_optimizer = tf.keras.optimizers.Adam(1e-4)
gen_optimizer = tf.keras.optimizers.Adam(1e-4)

BATCH_SIZE = 256
BUFFER_SIZE = 100
num_examples_to_generate = 16
noise_dim = 100
seed = tf.random.normal([num_examples_to_generate, noise_dim])

train_data = tf.data.Dataset.from_tensor_slices(
    train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)


def generate_and_save_images(model, epoch, test_input):
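    # Added sketch, not part of the original example: the snippet is cut off at
    # this definition. A body in the style of the common DCGAN tutorial (grid
    # size, filename and output scaling are assumptions; matplotlib.pyplot is
    # assumed to be imported as plt):
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.close(fig)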
Example #8
else:
    color = 1

train_images = train_images.reshape(img_count, width, height,
                                    color).astype('float32')
train_images = (train_images -
                127.5) / 127.5  # Normalize the images to [-1, 1]

BUFFER_SIZE = img_count
BATCH_SIZE = 256

# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

generator = make_generator_model(width, height, color)
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

discriminator = make_discriminator_model(width, height, color)
decision = discriminator(generated_image)

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(