Beispiel #1
0
def main():
    """Generate one image with the generator and display it.

    Builds the generator/discriminator pair and their Adam optimizers,
    restores the newest checkpoint found under ``./checkpoints`` (if any),
    feeds a single random noise vector of length 100 through the generator
    and shows the first channel of the result as a grayscale image.
    """
    generator = make_generator_model()
    discriminator = make_discriminator_model()
    generator_optimizer = tf.keras.optimizers.Adam(1e-4)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

    # The keyword names below are the keys under which the objects are
    # tracked by the checkpoint, so they must stay as they are.
    checkpoint_dir = './checkpoints'
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer,
                                     generator=generator,
                                     discriminator=discriminator)
    manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)

    latest_checkpoint = manager.latest_checkpoint
    if latest_checkpoint:
        # Only the generator is run below; the optimizers and the
        # discriminator are never used, so part of the checkpoint may stay
        # unmatched. expect_partial() marks that as intentional instead of
        # letting TensorFlow warn about an incomplete restore.
        checkpoint.restore(latest_checkpoint).expect_partial()
        print("Restored from {}".format(latest_checkpoint))
    else:
        print("Initializing from scratch.")

    noise = tf.random.normal([1, 100])
    generated_image = generator(noise, training=False)

    # Rescale for display; presumably the generator emits values in [-1, 1],
    # which this maps to [0, 255] -- TODO confirm against the model.
    plt.imshow(generated_image[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.show()
Beispiel #2
0
# GPU session setup (TF1-style session API): let the allocation grow on
# demand and cap the per-process share of GPU memory at 99%.
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.99
sess = tf.Session(config=config)
tf.keras.backend.set_session(sess)

# Run settings. `args` is defined outside this fragment -- presumably the
# parsed command-line namespace.
mb_size = 16
img_size = 256
in_depth = args.depth
disc_iters = args.itd
gene_iters = args.itg
lambda_mse = args.lmse
lambda_adv = args.ladv
lambda_perc = args.lperc

res_dir = 'videos/' + args.resfolder

# The generator accepts any spatial size; the discriminator is built for
# fixed img_size x img_size single-channel inputs.
generator = make_generator_model(
    input_shape=(None, None, in_depth),
    nlayers=args.lunet,
)
discriminator = make_discriminator_model(
    input_shape=(img_size, img_size, 1),
)

# input range should be [0, 255]
feature_extractor_vgg = tf.keras.applications.VGG19(
    weights='vgg19_weights_notop.h5',
    include_top=False,
)

time_dit_st = time.time()

generator.load_weights(args.weights)
samples = ['s1', 's2', 's3']
# 0% is GT (clean) and the rest are noisy inputs
p_noisy = ['10', '15', '20', '25']
frame_count = 279  # number of frames per case
denoise_frames = []
noisy_frames = []
Beispiel #3
0
def main():
    """Train a generator/discriminator pair on the MNIST training images.

    Loads MNIST, scales the pixels to [-1, 1], builds both models with one
    Adam optimizer each, then trains for ``EPOCHS`` epochs. A checkpoint is
    written to ``./checkpoints`` after every second epoch.
    """
    (digits, _), (_, _) = tf.keras.datasets.mnist.load_data()
    digits = digits.reshape(digits.shape[0], 28, 28, 1).astype('float32')
    # Map pixel values from [0, 255] to [-1, 1].
    digits = (digits - 127.5) / 127.5

    # Shuffled, batched input pipeline.
    train_dataset = tf.data.Dataset.from_tensor_slices(digits)
    train_dataset = train_dataset.shuffle(BUFFER_SIZE)
    train_dataset = train_dataset.batch(BATCH_SIZE)

    generator = make_generator_model()
    discriminator = make_discriminator_model()

    generator_optimizer = tf.keras.optimizers.Adam(1e-4)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

    # Everything needed to resume: both models, both optimizers and a
    # counter that is advanced once per epoch.
    checkpoint = tf.train.Checkpoint(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator,
        step=tf.Variable(0),
    )
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, './checkpoints', max_to_keep=3)

    # The discriminator output is treated as raw logits.
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def loss_for_discriminator(real_logits, fake_logits):
        """Real samples should score as 1, generated samples as 0."""
        on_real = bce(tf.ones_like(real_logits), real_logits)
        on_fake = bce(tf.zeros_like(fake_logits), fake_logits)
        return on_real + on_fake

    def loss_for_generator(fake_logits):
        """The generator wants its samples to score as 1."""
        return bce(tf.ones_like(fake_logits), fake_logits)

    @tf.function
    def train_step(images):
        """Run one simultaneous update of both networks on one batch."""
        latent = tf.random.normal([BATCH_SIZE, NOISE_DIM])

        # One tape per network so each loss is differentiated separately.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            fakes = generator(latent, training=True)
            real_logits = discriminator(images, training=True)
            fake_logits = discriminator(fakes, training=True)

            gen_loss = loss_for_generator(fake_logits)
            disc_loss = loss_for_discriminator(real_logits, fake_logits)

        gen_vars = generator.trainable_variables
        disc_vars = discriminator.trainable_variables
        gen_grads = gen_tape.gradient(gen_loss, gen_vars)
        disc_grads = disc_tape.gradient(disc_loss, disc_vars)

        generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
        discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    def train(dataset, epochs):
        """Iterate over the dataset `epochs` times, checkpointing as it goes."""
        for epoch in range(1, epochs + 1):
            started_at = time.time()

            for batch in dataset:
                train_step(batch)

            checkpoint.step.assign_add(1)
            # Save on even-numbered epochs only.
            if epoch % 2 == 0:
                saved_to = checkpoint_manager.save()
                print("Saved checkpoint for step {}: {}".format(
                    int(checkpoint.step), saved_to))

            print('Time for epoch {} is {} sec'.format(
                epoch, time.time() - started_at))

    train(train_dataset, EPOCHS)
Beispiel #4
0
def main():
    """Train a CycleGAN (A <-> B image translation) from TFRecord files.

    Command-line arguments:
        --dataset: dataset name (required). Selects
            ``tfrecords/<name>/trainA.tfrecord`` and ``trainB.tfrecord`` as
            inputs, ``./checkpoints-<name>`` for checkpoints and
            ``logs/<name>/<timestamp>/train`` for TensorBoard summaries.
        --batch_size: number of image pairs per batch (default 4).

    Training resumes from the latest checkpoint when one exists. A checkpoint
    is saved after every second epoch, and the per-epoch mean of every loss
    is written as a TensorBoard scalar at the end of each epoch.
    """
    parser = argparse.ArgumentParser(
        description='Train a CycleGAN model on a TFRecord dataset.')
    parser.add_argument(
        '--dataset', help='The name of the dataset', required=True)
    parser.add_argument(
        '--batch_size', help='The batch size of input data',
        type=int, default=4)
    args = parser.parse_args()

    # One running-mean metric per reported loss, keyed by its TensorBoard
    # tag. Keeping them in a single mapping guarantees that every metric is
    # both written and reset at the end of each epoch.
    metric_tags = (
        'loss_gen_a2b',
        'loss_gen_b2a',
        'loss_dis_b',
        'loss_dis_a',
        'loss_id_a2b',
        'loss_id_b2a',
        'loss_gen_total',
        'loss_dis_total',
        'loss_cycle_a2b2a',
        'loss_cycle_b2a2b',
    )
    loss_metrics = {
        tag: tf.keras.metrics.Mean(tag + '_metrics', dtype=tf.float32)
        for tag in metric_tags
    }

    mse_loss = tf.keras.losses.MeanSquaredError()
    mae_loss = tf.keras.losses.MeanAbsoluteError()
    fake_pool_b2a = ImagePool(POOL_SIZE)
    fake_pool_a2b = ImagePool(POOL_SIZE)

    def calc_gan_loss(prediction, is_real):
        """Squared-error GAN loss against an all-ones or all-zeros target."""
        if is_real:
            return mse_loss(prediction, tf.ones_like(prediction))
        return mse_loss(prediction, tf.zeros_like(prediction))

    def calc_cycle_loss(reconstructed_images, real_images):
        """Mean absolute error between a reconstruction and its source."""
        return mae_loss(reconstructed_images, real_images)

    def calc_identity_loss(identity_images, real_images):
        """Mean absolute error between G(x) and x for x already in G's target domain.

        Ideally, feeding a real target-domain image to a generator returns
        the image unchanged.
        """
        return mae_loss(identity_images, real_images)

    def make_dataset(filepath):
        """Return a dataset of augmented, normalized images from a TFRecord file."""
        raw_dataset = tf.data.TFRecordDataset(filepath)

        image_feature_description = {
            'image/height': tf.io.FixedLenFeature([], tf.int64),
            'image/width': tf.io.FixedLenFeature([], tf.int64),
            'image/format': tf.io.FixedLenFeature([], tf.string),
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
        }

        def preprocess_image(encoded_image):
            """Decode a JPEG and apply flip / resize / crop / normalization."""
            image = tf.image.decode_jpeg(encoded_image, 3)
            # random flip left or right
            image = tf.image.random_flip_left_right(image)
            # resize to 286x286
            image = tf.image.resize(image, [286, 286])
            # random crop a 256x256 area
            image = tf.image.random_crop(
                image, [256, 256, tf.shape(image)[-1]])
            # normalize from 0-255 to -1 ~ +1
            image = image / 127.5 - 1
            return image

        def parse_image_function(example_proto):
            """Parse one serialized tf.Example and return its processed image."""
            features = tf.io.parse_single_example(example_proto,
                                                  image_feature_description)
            return preprocess_image(features['image/encoded'])

        return raw_dataset.map(parse_image_function)

    def count_dataset_batches(dataset):
        """Count batches by iterating over the whole dataset once."""
        return sum(1 for _ in dataset)

    train_a = make_dataset('tfrecords/{}/trainA.tfrecord'.format(args.dataset))
    train_b = make_dataset('tfrecords/{}/trainB.tfrecord'.format(args.dataset))
    combined_dataset = tf.data.Dataset.zip(
        (train_a, train_b)).shuffle(SHUFFLE_SIZE).batch(args.batch_size)
    total_batches = count_dataset_batches(combined_dataset)
    print('Batch size: {}, Total batches per epoch: {}'.format(
        args.batch_size, total_batches))

    generator_a2b = make_generator_model(n_blocks=9)
    generator_b2a = make_generator_model(n_blocks=9)
    discriminator_b = make_discriminator_model()
    discriminator_a = make_discriminator_model()
    # The schedulers are configured in optimizer steps (epochs * batches).
    gen_lr_scheduler = LinearDecay(LEARNING_RATE, EPOCHS * total_batches,
                                   DECAY_EPOCHS * total_batches)
    dis_lr_scheduler = LinearDecay(LEARNING_RATE, EPOCHS * total_batches,
                                   DECAY_EPOCHS * total_batches)
    optimizer_gen = tf.keras.optimizers.Adam(gen_lr_scheduler, BETA_1)
    optimizer_dis = tf.keras.optimizers.Adam(dis_lr_scheduler, BETA_1)

    # The keyword names are the keys stored in the checkpoint; keep them
    # stable so existing checkpoints remain loadable.
    checkpoint_dir = './checkpoints-{}'.format(args.dataset)
    checkpoint = tf.train.Checkpoint(
        generator_a2b=generator_a2b,
        generator_b2a=generator_b2a,
        discriminator_b=discriminator_b,
        discriminator_a=discriminator_a,
        optimizer_gen=optimizer_gen,
        optimizer_dis=optimizer_dis,
        epoch=tf.Variable(0))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=None)
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    @tf.function
    def train_generator(images_a, images_b):
        """Update both generators on one batch.

        Returns the two translated batches (A->B, B->A) and a dict of the
        individual generator losses.
        """
        real_a = images_a
        real_b = images_b

        with tf.GradientTape() as tape:
            # Cycle A -> B -> A
            fake_a2b = generator_a2b(real_a, training=True)
            recon_b2a = generator_b2a(fake_a2b, training=True)
            # Cycle B -> A -> B
            fake_b2a = generator_b2a(real_b, training=True)
            recon_a2b = generator_a2b(fake_b2a, training=True)

            # Use real B to generate B should be identical
            identity_a2b = generator_a2b(real_b, training=True)
            identity_b2a = generator_b2a(real_a, training=True)
            loss_identity_a2b = calc_identity_loss(identity_a2b, real_b)
            loss_identity_b2a = calc_identity_loss(identity_b2a, real_a)

            # Generator A2B tries to trick Discriminator B that the generated image is B
            loss_gan_gen_a2b = calc_gan_loss(
                discriminator_b(fake_a2b, training=True), True)
            # Generator B2A tries to trick Discriminator A that the generated image is A
            loss_gan_gen_b2a = calc_gan_loss(
                discriminator_a(fake_b2a, training=True), True)
            loss_cycle_a2b2a = calc_cycle_loss(recon_b2a, real_a)
            loss_cycle_b2a2b = calc_cycle_loss(recon_a2b, real_b)

            # Total generator loss
            loss_gen_total = loss_gan_gen_a2b + loss_gan_gen_b2a \
                + (loss_cycle_a2b2a + loss_cycle_b2a2b) * LAMBDA_CYCLE \
                + (loss_identity_a2b + loss_identity_b2a) * LAMBDA_ID

        trainable_variables = generator_a2b.trainable_variables + generator_b2a.trainable_variables
        gradient_gen = tape.gradient(loss_gen_total, trainable_variables)
        optimizer_gen.apply_gradients(zip(gradient_gen, trainable_variables))

        loss_dict = {
            'loss_gen_a2b': loss_gan_gen_a2b,
            'loss_gen_b2a': loss_gan_gen_b2a,
            'loss_id_a2b': loss_identity_a2b,
            'loss_id_b2a': loss_identity_b2a,
            'loss_cycle_a2b2a': loss_cycle_a2b2a,
            'loss_cycle_b2a2b': loss_cycle_b2a2b,
            'loss_gen_total': loss_gen_total,
        }
        # The dict keys double as metric tags.
        for tag, value in loss_dict.items():
            loss_metrics[tag](value)
        return fake_a2b, fake_b2a, loss_dict

    @tf.function
    def train_discriminator(images_a, images_b, fake_a2b, fake_b2a):
        """Update both discriminators on one batch and return their losses."""
        real_a = images_a
        real_b = images_b

        with tf.GradientTape() as tape:
            # Discriminator A should classify real_a as A
            loss_gan_dis_a_real = calc_gan_loss(
                discriminator_a(real_a, training=True), True)
            # Discriminator A should classify generated fake_b2a as not A
            loss_gan_dis_a_fake = calc_gan_loss(
                discriminator_a(fake_b2a, training=True), False)

            # Discriminator B should classify real_b as B
            loss_gan_dis_b_real = calc_gan_loss(
                discriminator_b(real_b, training=True), True)
            # Discriminator B should classify generated fake_a2b as not B
            loss_gan_dis_b_fake = calc_gan_loss(
                discriminator_b(fake_a2b, training=True), False)

            # Total discriminator loss
            loss_dis_a = (loss_gan_dis_a_real + loss_gan_dis_a_fake) * 0.5
            loss_dis_b = (loss_gan_dis_b_real + loss_gan_dis_b_fake) * 0.5
            loss_dis_total = loss_dis_a + loss_dis_b

        trainable_variables = discriminator_a.trainable_variables + discriminator_b.trainable_variables
        gradient_dis = tape.gradient(loss_dis_total, trainable_variables)
        optimizer_dis.apply_gradients(zip(gradient_dis, trainable_variables))

        loss_dict = {
            'loss_dis_b': loss_dis_b,
            'loss_dis_a': loss_dis_a,
            'loss_dis_total': loss_dis_total,
        }
        # The dict keys double as metric tags.
        for tag, value in loss_dict.items():
            loss_metrics[tag](value)
        return loss_dict

    def train_step(images_a, images_b, epoch, step):
        """Run one generator update, then one discriminator update, and log."""
        fake_a2b, fake_b2a, gen_loss_dict = train_generator(images_a, images_b)

        # The discriminators are trained on whatever the image pools return
        # for the freshly generated batches.
        fake_b2a_from_pool = fake_pool_b2a.query(fake_b2a)
        fake_a2b_from_pool = fake_pool_a2b.query(fake_a2b)

        dis_loss_dict = train_discriminator(
            images_a, images_b, fake_a2b_from_pool, fake_b2a_from_pool)

        gen_loss_list = [
            '{}:{} '.format(k, v) for k, v in gen_loss_dict.items()
        ]
        dis_loss_list = [
            '{}:{} '.format(k, v) for k, v in dis_loss_dict.items()
        ]

        tf.print('Epoch {} Step {} '.format(epoch, step),
                 ' '.join(gen_loss_list + dis_loss_list))

    current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = 'logs/{}/{}/train'.format(args.dataset, current_time)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    def write_metrics(epoch):
        """Write the epoch means to TensorBoard, then reset every metric."""
        with train_summary_writer.as_default():
            for tag, metric in loss_metrics.items():
                tf.summary.scalar(tag, metric.result(), step=epoch)
            tf.summary.scalar(
                'gen_learning_rate',
                gen_lr_scheduler.current_learning_rate,
                step=epoch)
            tf.summary.scalar(
                'dis_learning_rate',
                dis_lr_scheduler.current_learning_rate,
                step=epoch)

        # Reset all metrics (including the total and cycle losses) so the
        # next epoch's scalars are per-epoch means rather than running means
        # over the whole run.
        for metric in loss_metrics.values():
            metric.reset_states()

    def train(dataset, epochs):
        """Train from the epoch after the restored one up to `epochs`."""
        first_epoch = int(checkpoint.epoch) + 1
        for epoch in range(first_epoch, epochs + 1):
            start = time.time()
            print('Epoch {} starts. Learning rate: {}, {}'.format(
                epoch, gen_lr_scheduler.current_learning_rate,
                dis_lr_scheduler.current_learning_rate))

            # Training
            for (step, batch) in enumerate(dataset):
                train_step(batch[0], batch[1], epoch, step)

            # Update TensorBoard metrics
            write_metrics(epoch)

            # Save checkpoint
            checkpoint.epoch.assign_add(1)
            if epoch % 2 == 0:
                save_path = checkpoint_manager.save()
                print("Saved checkpoint for epoch {}: {}".format(
                    int(checkpoint.epoch), save_path))

            print('Time for epoch {} is {} sec'.format(epoch,
                                                       time.time() - start))

    # for local testing
    # seed1 = tf.random.normal([2, 256, 256, 3])
    # seed2 = tf.random.normal([2, 256, 256, 3])
    # combined_dataset = [(seed1, seed2)]
    # EPOCHS = 102

    train(combined_dataset, EPOCHS)
    print('Finished training.')
Beispiel #5
0
    shutil.rmtree(itr_out_dir)
# Create the directory that holds the temporary per-iteration output.
os.mkdir(itr_out_dir)

# When args.print is 0, send everything written to stdout to a log file
# inside the output directory instead of the console.
if args.print == 0:
    sys.stdout = open('%s/%s' % (itr_out_dir, 'iter-prints.log'), 'w')

# Minibatch generator wrapped by bkgdGen, with max_prefetch set to four
# minibatches' worth.
mb_data_iter = bkgdGen(
    data_generator=gen_train_batch_bg(
        dsfn=args.dsfn,
        mb_size=args.mbsz,
        in_depth=args.depth,
        img_size=args.psz,
    ),
    max_prefetch=args.mbsz * 4,
)

# The generator accepts any spatial size; the discriminator is built for
# fixed psz x psz single-channel inputs.
generator = make_generator_model(
    input_shape=(None, None, args.depth),
    nlayers=args.lunet,
)
discriminator = make_discriminator_model(
    input_shape=(args.psz, args.psz, 1),
)

# VGG19 without its classification head, loaded from a local weights file.
feature_extractor_vgg = tf.keras.applications.VGG19(
    weights='vgg19_weights_notop.h5',
    include_top=False,
)

# Binary cross-entropy helper that expects raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Sums the cross-entropy of `real_output` against an all-ones target and
    the cross-entropy of `fake_output` against an all-zeros target.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake
Beispiel #6
0
print(f"batch_size {batch_size}")

# Training hyper-parameters.
EPOCHS = 900
num_examples_to_generate = 25
noise_dim = 200
decay_step = 10
lr_initial_g = 0.0002
lr_decay_steps = 1000
replay_step = 32

# Fixed batch of noise vectors, one per example to generate.
seed = tf.random.normal([num_examples_to_generate, noise_dim])

# Non-trainable scalar handed to the discriminator builder.
noise_var = tf.Variable(initial_value=0.005, trainable=False, name="noiseIn")
noise_var.assign(0.005)

generator = make_generator_model(original_w, noise_dim)
discriminator = make_discriminator_model(original_w, noise_var)

# Smoke test: push one random noise vector through both untrained models
# and print the discriminator's output.
noise = tf.random.normal([1, noise_dim])
generated_image = generator(noise, training=False)
decision = discriminator(generated_image)
print(decision)
# plt.imshow( generated_image[0] *0.5 + 0.5  )
# plt.show()

generator.summary()
discriminator.summary()

# Binary cross-entropy helper that expects raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def invert_if(v1):