Beispiel #1
0
def define_objective(charmap, real_inputs_discrete, seq_length):
    """Wire up the GAN graph: generator, three discriminator passes, losses.

    Returns (disc_cost, gen_cost, train_pred, disc_fake, disc_real,
    disc_on_inference, inference_op).
    """
    vocab_size = len(charmap)
    real_inputs = tf.one_hot(real_inputs_discrete, vocab_size)

    # Resolve the model-builder callables configured via flags.
    generator_fn = get_generator(FLAGS.GENERATOR_MODEL)
    discriminator_fn = get_discriminator(FLAGS.DISCRIMINATOR_MODEL)

    train_pred, inference_op = generator_fn(
        BATCH_SIZE, vocab_size, seq_len=seq_length, gt=real_inputs)

    real_inputs_substrings = get_substrings_from_gt(
        real_inputs, seq_length, vocab_size)

    # First discriminator call creates the variables; later calls reuse them.
    disc_real = discriminator_fn(
        real_inputs_substrings, vocab_size, seq_length, reuse=False)
    disc_fake = discriminator_fn(
        train_pred, vocab_size, seq_length, reuse=True)
    disc_on_inference = discriminator_fn(
        inference_op, vocab_size, seq_length, reuse=True)

    disc_cost, gen_cost = loss_d_g(
        disc_fake, disc_real, train_pred, real_inputs_substrings,
        charmap, seq_length, discriminator_fn)
    return disc_cost, gen_cost, train_pred, disc_fake, disc_real, disc_on_inference, inference_op
Beispiel #2
0
def define_objective(charmap, real_inputs_discrete, seq_length, gan_type="wgan", rnn_cell=None):
    """Build generator/discriminator ops and the loss for the chosen GAN type.

    Returns (disc_cost, gen_cost, train_pred, disc_fake, disc_real,
    disc_on_inference, inference_op, other_ops); other_ops may carry
    extra training ops (e.g. FisherGAN's alpha optimizer).
    """
    assert gan_type in ["wgan", "fgan", "cgan"]
    assert rnn_cell

    other_ops = {}
    vocab_size = len(charmap)
    real_inputs = tf.one_hot(real_inputs_discrete, vocab_size)

    generator_fn = get_generator(FLAGS.GENERATOR_MODEL)
    discriminator_fn = get_discriminator(FLAGS.DISCRIMINATOR_MODEL)

    train_pred, inference_op = generator_fn(
        BATCH_SIZE, vocab_size, seq_len=seq_length,
        gt=real_inputs, rnn_cell=rnn_cell)

    real_inputs_substrings = get_substrings_from_gt(
        real_inputs, seq_length, vocab_size)

    # First call creates the discriminator variables; later calls reuse them.
    disc_real = discriminator_fn(real_inputs_substrings, vocab_size,
                                 seq_length, reuse=False, rnn_cell=rnn_cell)
    disc_fake = discriminator_fn(train_pred, vocab_size,
                                 seq_length, reuse=True, rnn_cell=rnn_cell)
    disc_on_inference = discriminator_fn(inference_op, vocab_size,
                                         seq_length, reuse=True, rnn_cell=rnn_cell)

    if gan_type == "wgan":
        disc_cost, gen_cost = loss_d_g(
            disc_fake, disc_real, train_pred, real_inputs_substrings,
            charmap, seq_length, discriminator_fn, rnn_cell)
    elif gan_type == "fgan":
        fgan = FisherGAN()
        disc_cost, gen_cost = fgan.loss_d_g(
            disc_fake, disc_real, train_pred, real_inputs_substrings,
            charmap, seq_length, discriminator_fn)
        other_ops["alpha_optimizer_op"] = fgan.alpha_optimizer_op
    else:
        raise NotImplementedError("Cramer GAN not implemented")

    return disc_cost, gen_cost, train_pred, disc_fake, disc_real, disc_on_inference, inference_op, other_ops
Beispiel #3
0
def train():
    """Train a DCGAN on CelebA with TF2 eager execution.

    Alternates generator/discriminator updates computed from a single
    persistent GradientTape, logs losses every `print_every_step` steps
    and saves weights plus a tiled sample image every `save_every_epoch`
    epochs. Relies on module-level `flags`, `num_tiles`, `tf`, `tl`,
    `np`, `time` and the `get_*` helpers.
    """
    images, images_path = get_celebA(flags.output_size, flags.n_epoch,
                                     flags.batch_size)
    G = get_generator([None, flags.z_dim])
    D = get_discriminator(
        [None, flags.output_size, flags.output_size, flags.c_dim])

    # Put both networks in training mode (enables train-time layer behavior).
    G.train()
    D.train()

    d_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)
    g_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)

    n_step_epoch = int(len(images_path) // flags.batch_size)

    for epoch in range(flags.n_epoch):
        for step, batch_images in enumerate(images):
            # Skip a trailing partial batch; shapes below assume batch_size.
            if batch_images.shape[0] != flags.batch_size:
                break

            step_time = time.time()
            # persistent=True: the same tape yields gradients for G and D.
            with tf.GradientTape(persistent=True) as tape:
                z = np.random.normal(loc=0.0,
                                     scale=1.0,
                                     size=[flags.batch_size,
                                           flags.z_dim]).astype(np.float32)
                d_logits = D(G(z))
                d2_logits = D(batch_images)
                # discriminator: real images are labelled as 1
                d_loss_real = tl.cost.sigmoid_cross_entropy(
                    d2_logits, tf.ones_like(d2_logits), name='dreal')
                # discriminator: images from generator (fake) are labelled as 0
                d_loss_fake = tl.cost.sigmoid_cross_entropy(
                    d_logits, tf.zeros_like(d_logits), name='dfake')
                # combined loss for updating discriminator
                d_loss = d_loss_real + d_loss_fake
                # generator: try to fool discriminator to output 1
                g_loss = tl.cost.sigmoid_cross_entropy(d_logits,
                                                       tf.ones_like(d_logits),
                                                       name='gfake')

            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
            grad = tape.gradient(d_loss, D.trainable_weights)
            d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
            # Release the persistent tape's held resources.
            del tape

            if step % flags.print_every_step == 0:
                print("Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".format(epoch, \
                      flags.n_epoch, step, n_step_epoch, time.time()-step_time, d_loss, g_loss))

        if np.mod(epoch, flags.save_every_epoch) == 0:
            G.save_weights('{}/G_{}.h5'.format(flags.checkpoint_dir, epoch))
            D.save_weights('{}/D_{}.h5'.format(flags.checkpoint_dir, epoch))
            # Sample with the last z of the epoch; eval() switches to
            # inference-mode layers, then restore training mode.
            G.eval()
            result = G(z)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [num_tiles, num_tiles],
                '{}/train_{:02d}.png'.format(flags.sample_dir, epoch))
Beispiel #4
0
    def build_model(self):
        """Build StarGAN-style Keras models.

        Creates generator `self.G`, discriminator `self.D`, the compiled
        `self.combined` model that trains G (reconstruction + adversarial +
        classification losses) and the compiled `self.DIS` model that trains
        D with a gradient-norm (WGAN-GP style) penalty.
        """
        self.G = get_generator(self.g_conv_dim, self.n_labels,
                               self.g_repeat_num, self.image_size)
        self.D = get_discriminator(self.d_conv_dim, self.n_labels,
                                   self.d_repeat_num, self.image_size)

        print(self.G.summary())

        self.d_optimizer = keras.optimizers.Adam(lr=self.d_lr,
                                                 beta_1=self.beta_1,
                                                 beta_2=self.beta_2)
        # FIX: beta_1 was mistakenly set to self.beta_2 here (copy/paste
        # typo); mirror the discriminator optimizer's beta_1=self.beta_1.
        self.g_optimizer = keras.optimizers.Adam(lr=self.g_lr,
                                                 beta_1=self.beta_1,
                                                 beta_2=self.beta_2)

        # Freeze D while the combined (generator-training) model is compiled.
        self.D.trainable = False

        combined_real_img = Input(shape=(self.image_size, self.image_size, 3))
        input_orig_labels = Input(shape=(self.image_size, self.image_size,
                                         self.n_labels))
        input_target_labels = Input(shape=(self.image_size, self.image_size,
                                           self.n_labels))

        # Condition G by concatenating the target labels onto the image.
        concatted_input = Concatenate(axis=3)(
            [combined_real_img, input_target_labels])

        combined_fake_img = self.G(concatted_input)
        output_src, output_cls = self.D(combined_fake_img)
        # Cycle branch: translate back using the original labels.
        concatted_combined_fake_img = Concatenate(axis=3)(
            [combined_fake_img, input_orig_labels])
        reconstr_img = self.G(concatted_combined_fake_img)

        self.combined = Model(
            inputs=[combined_real_img, input_orig_labels, input_target_labels],
            outputs=[reconstr_img, output_src, output_cls])

        self.combined.compile(
            loss=["mae", neg_mean_loss, self.custom_bin],
            loss_weights=[self.lambda_rec, 1, self.lambda_cls],
            optimizer=self.g_optimizer)

        # Discriminator-training model with gradient-norm penalty on an
        # interpolated input (GradNorm).
        shape = (self.image_size, self.image_size, 3)
        fake_input, real_input, interpolation = Input(shape), Input(
            shape), Input(shape)
        norm = GradNorm()([self.D(interpolation)[0], interpolation])
        fake_output_src, fake_output_cls = self.D(fake_input)
        real_output_src, real_output_cls = self.D(real_input)
        self.DIS = Model(
            [real_input, fake_input, interpolation],
            [fake_output_src, real_output_src, real_output_cls, norm])

        # Re-enable D's weights for the DIS model's training.
        self.D.trainable = True

        self.DIS.compile(
            loss=[mean_loss, neg_mean_loss, self.custom_bin, 'mse'],
            loss_weights=[1, 1, self.lambda_cls, self.lambda_gp],
            optimizer=self.d_optimizer)
Beispiel #5
0
def train():
    """Train a DCGAN on CelebA in TF1 graph mode with an interactive session.

    Builds the full loss/optimizer graph up front, then runs
    `n_epoch * n_step_epoch` fused session steps. Weights and a tiled
    sample image are saved every `FLAGS.save_step` steps. Relies on
    module-level `FLAGS`, `num_tiles`, `tf`, `tl`, `np`, `time` and the
    `get_*` helpers.
    """
    # Noise sampler node: a fresh z is drawn on every session run.
    z = tf.contrib.distributions.Normal(0., 1.).sample([FLAGS.batch_size, FLAGS.z_dim]) #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
    ds, images_path = get_celebA(FLAGS.output_size, FLAGS.n_epoch, FLAGS.batch_size)
    iterator = ds.make_one_shot_iterator()
    images = iterator.get_next()

    G = get_generator([None, FLAGS.z_dim])
    D = get_discriminator([None, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim])

    G.train()
    D.train()
    fake_images = G(z)
    d_logits = D(fake_images)
    d2_logits = D(images)

    # discriminator: real images are labelled as 1
    d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')
    # discriminator: images from generator (fake) are labelled as 0
    d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
    # cost for updating discriminator
    d_loss = d_loss_real + d_loss_fake

    # generator: try to make the the fake images look real (1)
    g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')
    # Define optimizers for updating discriminator and generator
    d_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1) \
                      .minimize(d_loss, var_list=D.weights)
    g_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1) \
                      .minimize(g_loss, var_list=G.weights)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    n_step_epoch = int(len(images_path) // FLAGS.batch_size)
    for epoch in range(FLAGS.n_epoch):
        epoch_time = time.time()
        for step in range(n_step_epoch):
            step_time = time.time()
            # One fused run: both losses and both optimizer ops share the
            # same z draw and input batch.
            _d_loss, _g_loss, _, _ = sess.run([d_loss, g_loss, d_optim, g_optim])
            print("Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".format(epoch, FLAGS.n_epoch, step, n_step_epoch, time.time()-step_time, _d_loss, _g_loss))
            if np.mod(step, FLAGS.save_step) == 0:
                G.save_weights('{}/G.npz'.format(FLAGS.checkpoint_dir), sess=sess, format='npz')
                D.save_weights('{}/D.npz'.format(FLAGS.checkpoint_dir), sess=sess, format='npz')
                # Evaluate the fake_images node (draws a new z) for samples.
                result = sess.run(fake_images)
                tl.visualize.save_images(result, [num_tiles, num_tiles], '{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, step))

    sess.close()
Beispiel #6
0
def define_class_objective(charmap, real_inputs_discrete, real_class_discrete,
                           seq_length, num_classes):
    """Build the class-conditional GAN objective.

    Returns (disc_cost, gen_cost, train_pred, disc_fake, disc_fake_class,
    disc_real, disc_real_class, disc_on_inference,
    disc_on_inference_class, inference_op).
    """
    vocab_size = len(charmap)
    real_inputs = tf.one_hot(real_inputs_discrete, vocab_size)

    generator_fn = get_generator(FLAGS.GENERATOR_MODEL)
    discriminator_fn = get_discriminator(FLAGS.DISCRIMINATOR_MODEL)

    train_pred, train_pred_class, inference_op = generator_fn(
        BATCH_SIZE,
        vocab_size,
        seq_len=seq_length,
        num_classes=num_classes,
        gt=real_inputs,
        gt_class=real_class_discrete)

    real_inputs_substrings, real_inputs_class = get_substrings_from_gt(
        real_inputs, real_class_discrete, seq_length, vocab_size)

    # First discriminator call creates variables; the rest reuse them.
    disc_real, disc_real_class = discriminator_fn(
        real_inputs_substrings, vocab_size, seq_length, num_classes,
        reuse=False)
    disc_fake, disc_fake_class = discriminator_fn(
        train_pred, vocab_size, seq_length, num_classes, reuse=True)
    disc_on_inference, disc_on_inference_class = discriminator_fn(
        inference_op, vocab_size, seq_length, num_classes, reuse=True)

    disc_cost, gen_cost = loss_d_g_class(disc_fake=disc_fake,
                                         disc_fake_class=disc_fake_class,
                                         disc_real=disc_real,
                                         disc_real_class=disc_real_class,
                                         gt_fake_class=train_pred_class,
                                         gt_real_class=real_inputs_class,
                                         num_classes=num_classes)

    return (disc_cost, gen_cost, train_pred, disc_fake, disc_fake_class,
            disc_real, disc_real_class, disc_on_inference,
            disc_on_inference_class, inference_op)
Beispiel #7
0
 def __init__(self, flags, type):
     """Prepare data, networks and optimizers for the selected GAN variant.

     `type` must be a key of the module-level `methods_dict`; WGAN uses
     RMSprop, every other variant uses Adam.
     """
     self.dataset, self.len_instance = get_mnist(flags.batch_size)
     self.G = get_generator([None, flags.z_dim],
                            gf_dim=64,
                            o_size=flags.output_size,
                            o_channel=flags.c_dim)
     self.D = get_discriminator(
         [None, flags.output_size, flags.output_size, flags.c_dim],
         df_dim=64)
     self.batch_size = flags.batch_size
     self.epoch = flags.n_epoch
     self.type = type
     assert type in methods_dict.keys()
     self.get_loss = methods_dict[type]
     # WGAN traditionally trains with RMSprop; other variants use Adam.
     if type == "WGAN":
         make_optimizer = lambda: tf.optimizers.RMSprop(flags.lr)
     else:
         make_optimizer = lambda: tf.optimizers.Adam(flags.lr,
                                                     beta_1=flags.beta1)
     self.d_optimizer = make_optimizer()
     self.g_optimizer = make_optimizer()
Beispiel #8
0
def train(imgs_train_path,imgs_test_path,output_models_path,output_images_path,batch_size,epochs,epoch_size = 1000,training_images_to_load = 3000,test_images = 50,training_for_generator_each_batch=2,save_data= True):
    """Train an SRGAN-style generator/discriminator pair.

    Loads paired (downscaled, full-size) images, optionally pickles them,
    then for each epoch: trains the discriminator on real vs. generated
    images (with label smoothing on fakes), trains the generator through
    the combined GAN model, evaluates on the test split, and every 3rd
    epoch saves both models plus sample images. Relies on module-level
    `utility`, `model`, `loss`, `IMAGE_SHAPE`, `DOWNSCALED_IMG_SHAPE`,
    `OPTIMIZER`, `np`, `time`.
    """
    train_x,train_y = utility.get_data(imgs_train_path,IMAGE_SHAPE,training_images_to_load)
    test_x,test_y = utility.get_data(imgs_test_path,IMAGE_SHAPE,test_images,False)

    if save_data:
        utility.save_data_as_pickle(train_x,"dataset/operative_data/train_x")
        utility.save_data_as_pickle(train_y,"dataset/operative_data/train_y")
        utility.save_data_as_pickle(test_x,"dataset/operative_data/test_x")
        utility.save_data_as_pickle(test_y,"dataset/operative_data/test_y")
        print("--SAVED DATA")


    generator = model.get_generator(IMAGE_SHAPE)
    discriminator = model.get_discriminator(IMAGE_SHAPE)
    # Perceptual (VGG feature) loss drives the generator.
    loss_generator = loss.VGG_LOSS(IMAGE_SHAPE)
    generator.compile(loss=loss_generator.vgg_loss, optimizer=OPTIMIZER)
    discriminator.compile(loss="binary_crossentropy", optimizer=OPTIMIZER)
    gan = model.get_gan_model(DOWNSCALED_IMG_SHAPE,generator,discriminator,loss_generator.vgg_loss,"binary_crossentropy",OPTIMIZER)
    gan.summary()

    n_batch = int((epoch_size)/ batch_size)
    n_batch_test = int((test_images/batch_size))

    # Target label vectors for the discriminator (real=1, fake=0).
    true_batch_vector = np.ones((batch_size,1))
    false_batch_vector = np.zeros((batch_size,1))
    

    print("--START TRAINING")
    for epoch in range(epochs):

        disciminator_losses = []
        gan_losses = []
        epoch_start_time = time.time()

        # train each batch
        for batch in range(n_batch):
            # Sample a random batch (with replacement across the epoch).
            random_indexes = np.random.randint(0, len(train_x), size=batch_size)

            batch_x  =  np.array(train_x)[random_indexes.astype(int)]
            batch_y =  np.array(train_y)[random_indexes.astype(int)]

            generated_images = generator.predict(x=batch_x, batch_size=batch_size)

            discriminator.trainable = True

            #we can decide to perform more than one train for batch on the discriminator 
            for _ in range(training_for_generator_each_batch):
                d_loss_r = discriminator.train_on_batch(batch_y, true_batch_vector)
                # Fake labels are smoothed into [0, 0.2) rather than hard 0.
                d_loss_f = discriminator.train_on_batch(generated_images, np.random.random_sample(batch_size)*0.2)      

                disciminator_losses.append(0.5 * np.add(d_loss_f, d_loss_r))

            # Freeze D while the combined model updates the generator.
            discriminator.trainable = False
            
            # train the generator 
            # Real labels smoothed into (0.8, 1].
            gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
            gan_loss = gan.train_on_batch(batch_x, [batch_y, gan_Y])
            gan_losses.append(gan_loss)

        # Evaluate on the test split without updating weights.
        test_losses = []
        for i in range(n_batch_test):
            batch_x = np.array(test_x)[i*batch_size:(i+1)*batch_size]
            batch_y = np.array(test_y)[i*batch_size:(i+1)*batch_size]
            gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
            gan_loss = gan.test_on_batch(batch_x, [batch_y, gan_Y])
            test_losses.append(gan_loss)

        #print("discriminator loss: ",np.mean(disciminator_losses), " gan losses: ",[np.mean(x) for x in zip(*gan_losses)] ," time: ",time.time()-epoch_start_time)
        print("test: ",[np.mean(x) for x in zip(*test_losses)])

        if epoch % 3 == 0 or epoch == 0:
            generator.save(output_models_path + 'gen_model%d.h5' % epoch)
            discriminator.save(output_models_path + 'dis_model%d.h5' % epoch)
            utility.plot_generated_images(output_images_path,epoch,generator,test_y,test_x)
Beispiel #9
0
def main(args):
    """Train an anomaly-detection GAN on selected MNIST digits (pytorch-ignite).

    Reads hyper-parameters from a JSON settings file, trains on the "neg"
    digit classes, and every 1000 iterations evaluates the detector on
    held-out pos/neg loaders, plots metrics, prints logs and saves sample
    images. Training stops at the configured iteration count.
    """
    result_dir_path = Path(args.result_dir)
    result_dir_path.mkdir(parents=True, exist_ok=True)

    with Path(args.setting).open("r") as f:
        setting = json.load(f)
    pprint.pprint(setting)

    # Use the requested CUDA device when available, otherwise fall back to CPU.
    if args.g >= 0 and torch.cuda.is_available():
        device = torch.device(f"cuda:{args.g:d}")
        print(f"GPU mode: {args.g:d}")
    else:
        device = torch.device("cpu")
        print("CPU mode")

    mnist_neg = get_mnist_num(set(setting["label"]["neg"]))
    neg_loader = DataLoader(mnist_neg,
                            batch_size=setting["iterator"]["batch_size"])

    generator = get_generator().to(device)
    discriminator = get_discriminator().to(device)
    opt_g = torch.optim.Adam(
        generator.parameters(),
        lr=setting["optimizer"]["alpha"],
        betas=(setting["optimizer"]["beta1"], setting["optimizer"]["beta2"]),
        weight_decay=setting["regularization"]["weight_decay"])
    opt_d = torch.optim.Adam(
        discriminator.parameters(),
        lr=setting["optimizer"]["alpha"],
        betas=(setting["optimizer"]["beta1"], setting["optimizer"]["beta2"]),
        weight_decay=setting["regularization"]["weight_decay"])

    trainer = Engine(
        GANTrainer(generator,
                   discriminator,
                   opt_g,
                   opt_d,
                   device=device,
                   **setting["updater"]))

    # For testing: held-out loaders and the anomaly detector built on G/D.
    test_neg = get_mnist_num(set(setting["label"]["neg"]), train=False)
    test_neg_loader = DataLoader(test_neg, setting["iterator"]["batch_size"])
    test_pos = get_mnist_num(set(setting["label"]["pos"]), train=False)
    test_pos_loader = DataLoader(test_pos, setting["iterator"]["batch_size"])
    detector = Detector(generator, discriminator,
                        setting["updater"]["noise_std"], device).to(device)

    # Periodic evaluation/reporting callbacks share a common log dict.
    log_dict = {}
    evaluator = evaluate_accuracy(log_dict, detector, test_neg_loader,
                                  test_pos_loader, device)
    plotter = plot_metrics(log_dict, ["accuracy", "precision", "recall", "f"],
                           "iteration", result_dir_path / "metrics.pdf")
    printer = print_logs(log_dict,
                         ["iteration", "accuracy", "precision", "recall", "f"])
    img_saver = save_img(generator, test_pos, test_neg,
                         result_dir_path / "images",
                         setting["updater"]["noise_std"], device)

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000),
                              evaluator)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), plotter)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), printer)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000),
                              img_saver)

    # Terminate at the configured iteration (max_epochs is effectively infinite).
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(once=setting["iteration"]),
        lambda engine: engine.terminate())
    trainer.run(neg_loader, max_epochs=10**10)
def train():
    """Horovod-distributed DCGAN training on CelebA (TF eager mode).

    Each worker trains on its shard of the data; worker 0 logs progress
    and saves checkpoints plus a tiled sample image every
    `FLAGS.save_step` steps. Relies on module-level `hvd`, `tf`, `tl`,
    `np`, `time`, `FLAGS`, `num_tiles` and the `get_*` helpers.
    """
    # Horovod: initialize Horovod.
    hvd.init()
    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    tf.enable_eager_execution(config=config)
    # Horovod: adjust number of steps based on number of GPUs.
    images, images_path = get_celebA(FLAGS.output_size, FLAGS.n_epoch // hvd.size(), FLAGS.batch_size)

    G = get_generator([None, FLAGS.z_dim])
    D = get_discriminator([None, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim])

    G.train()
    D.train()

    # Horovod linear scaling rule: learning rate scales with worker count.
    d_optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate * hvd.size(), beta1=FLAGS.beta1)
    g_optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate * hvd.size(), beta1=FLAGS.beta1)

    step_counter = tf.train.get_or_create_global_step()

    n_step_epoch = int(len(images_path) // FLAGS.batch_size)

    for step, batch_images in enumerate(images):
        step_time = time.time()
        # persistent=True: the same tape yields gradients for G and D.
        with tf.GradientTape(persistent=True) as tape:
            z = tf.contrib.distributions.Normal(0., 1.).sample([FLAGS.batch_size, FLAGS.z_dim]) #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
            d_logits = D(G(z))
            d2_logits = D(batch_images)
            # discriminator: real images are labelled as 1
            d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')
            # discriminator: images from generator (fake) are labelled as 0
            d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
            # cost for updating discriminator
            d_loss = d_loss_real + d_loss_fake
            # generator: try to make the fake images look real (1)
            g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        if step == 0:
            hvd.broadcast_variables(G.weights, root_rank=0)
            hvd.broadcast_variables(D.weights, root_rank=0)

        # Horovod: add Horovod Distributed GradientTape (allreduces gradients).
        tape = hvd.DistributedGradientTape(tape)

        grad = tape.gradient(d_loss, D.weights)
        d_optimizer.apply_gradients(zip(grad, D.weights), global_step=tf.train.get_or_create_global_step())
        grad = tape.gradient(g_loss, G.weights)
        g_optimizer.apply_gradients(zip(grad, G.weights), global_step=tf.train.get_or_create_global_step())

        # Horovod: print logging only on worker 0
        # FIX: the original line was missing the trailing colon (SyntaxError).
        if hvd.rank() == 0:
            print("Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".format(step//n_step_epoch, FLAGS.n_epoch, step, n_step_epoch, time.time()-step_time, d_loss, g_loss))

        # Horovod: save checkpoints only on worker 0
        if hvd.rank() == 0 and np.mod(step, FLAGS.save_step) == 0:
            G.save_weights('{}/G.npz'.format(FLAGS.checkpoint_dir), format='npz')
            D.save_weights('{}/D.npz'.format(FLAGS.checkpoint_dir), format='npz')
            result = G(z)
            tl.visualize.save_images(result.numpy(), [num_tiles, num_tiles], '{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, step//n_step_epoch, step))
Beispiel #11
0
def train():
    """Train a BiGAN-style model (generator, encoder, joint discriminator).

    The discriminator scores (image, latent) pairs: D([G(z), z]) for fakes
    and D([x, E(x)]) for reals. G and E are trained adversarially against
    D. Saves weights, sample and reconstruction images every
    `save_every_epoch` epochs. Relies on module-level `flags`,
    `num_tiles`, `tf`, `tl`, `np`, `time` and the `get_*` helpers.
    """
    images, images_path = get_celebA(flags.output_size, flags.n_epoch,
                                     flags.batch_size)
    G = get_generator([None, flags.z_dim])
    D = get_discriminator(
        [None, flags.z_dim],
        [None, flags.output_size, flags.output_size, flags.c_dim])
    E = get_encoder([None, flags.output_size, flags.output_size, flags.c_dim])

    # Optionally resume from previously saved weights.
    if flags.load_weights:
        E.load_weights('checkpoint/E.npz', format='npz')
        G.load_weights('checkpoint/G.npz', format='npz')
        D.load_weights('checkpoint/D.npz', format='npz')

    G.train()
    D.train()
    E.train()

    d_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)
    g_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)
    e_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)

    n_step_epoch = int(len(images_path) // flags.batch_size)

    for epoch in range(flags.n_epoch):
        for step, batch_images in enumerate(images):
            # Skip a trailing partial batch; shapes below assume batch_size.
            if batch_images.shape[0] != flags.batch_size:
                break
            step_time = time.time()

            # persistent=True: one tape provides gradients for G, D and E.
            with tf.GradientTape(persistent=True) as tape:
                z = np.random.normal(loc=0.0,
                                     scale=1.0,
                                     size=[flags.batch_size,
                                           flags.z_dim]).astype(np.float32)

                # Fake pair: generated image with its source latent.
                d_logits = D([G(z), z])
                # Real pair: real image with its encoded latent.
                d2_logits = D([batch_images, E(batch_images)])

                d_loss_real = tl.cost.sigmoid_cross_entropy(
                    d2_logits, tf.ones_like(d2_logits), name='dreal')
                d_loss_fake = tl.cost.sigmoid_cross_entropy(
                    d_logits, tf.zeros_like(d_logits), name='dfake')
                d_loss = d_loss_fake + d_loss_real

                # Generator tries to make D score fake pairs as real.
                g_loss = tl.cost.sigmoid_cross_entropy(d_logits,
                                                       tf.ones_like(d_logits),
                                                       name='gfake')

                # Encoder tries to make D score real pairs as fake.
                e_loss = tl.cost.sigmoid_cross_entropy(
                    d2_logits, tf.zeros_like(d2_logits), name='ereal')

            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
            grad = tape.gradient(d_loss, D.trainable_weights)
            d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
            grad = tape.gradient(e_loss, E.trainable_weights)
            e_optimizer.apply_gradients(zip(grad, E.trainable_weights))

            # Release the persistent tape's held resources.
            del tape

            print(
                "Epoch: [{}/{}] [{}/{}] took: {:.3f}, d_loss: {:.5f}, g_loss: {:.5f}, e_loss: {:.5f}"
                .format(epoch, flags.n_epoch, step, n_step_epoch,
                        time.time() - step_time, d_loss, g_loss, e_loss))

        if np.mod(epoch, flags.save_every_epoch) == 0:
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir),
                           format='npz')
            D.save_weights('{}/D.npz'.format(flags.checkpoint_dir),
                           format='npz')
            E.save_weights('{}/E.npz'.format(flags.checkpoint_dir),
                           format='npz')
            # Sample with the last z of the epoch in inference mode.
            G.eval()
            result = G(z)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [num_tiles, num_tiles],
                '{}/train_{:02d}.png'.format(flags.sample_dir, epoch))

            # Save one real batch and its G(E(x)) reconstruction side by side.
            for step, batch_images in enumerate(images):
                if batch_images.shape[0] != flags.batch_size:
                    break
                result = G(E(batch_images))
                tl.visualize.save_images(
                    batch_images.numpy(), [num_tiles, num_tiles],
                    '{}/real_{:02d}.png'.format(flags.pair_dir, epoch))
                tl.visualize.save_images(
                    result.numpy(), [num_tiles, num_tiles],
                    '{}/reproduced_{:02d}.png'.format(flags.pair_dir, epoch))
                break
Beispiel #12
0
def train():
    """Train a GANimation-style attention GAN on face images with AU labels.

    G maps (image, target AUs) to an (output image, attention mask) pair;
    the masked fake is cycled back with the original AUs. D is trained
    every step with a WGAN-GP critic loss plus an AU-regression loss; G is
    trained every 5th step with adversarial, AU, cycle and mask-sparsity/
    smoothness losses. Weights are saved every 10 epochs. Relies on
    module-level `tf`, `tl`, `np`, `time`, `BATCH_SIZE`, `EPOCHS`,
    `lr_fn`, the `lambda_*` weights and the loss helpers.
    """
    #load data
    face, au = load_data(face_dir, au_dir)
    # Shuffled-and-jittered AUs serve as the "desired" target expressions.
    au_rand = au.copy()
    np.random.shuffle(au_rand)
    au_rand += np.random.uniform(-0.1, 0.1, au_rand.shape)

    G = get_generator([None, 128, 128, 20])
    D = get_discriminator([None, 128, 128, 3])

    lr = 1e-4
    g_train_op = tf.optimizers.Adam(learning_rate=lr, beta_1=0.5, beta_2=0.999)
    d_train_op = tf.optimizers.Adam(learning_rate=lr, beta_1=0.5, beta_2=0.999)

    G.train()
    D.train()

    n_step_epoch = int(len(face) // BATCH_SIZE)
    num_tiles = int(np.sqrt(25))
    print('----------- start training -----------')
    for e in range(1, EPOCHS + 1):
        start_time = time.time()
        # learning-rate schedule (see lr_fn)

        print('===== [Epoch %02d/30](lr: %.5f) =====' % (e, lr_fn(e)))

        for i in range(len(face) // BATCH_SIZE):
            with tf.GradientTape(persistent=True) as tape:
                # fetch one batch of data
                real_img = face[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
                real_au = au[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
                desired_au = au_rand[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]

                # Forward translation: real image + desired AUs -> fake image.
                g_input_f = preprocess_data(real_img, desired_au)
                [fake_img, fake_mask] = G(g_input_f)
                # Attention mask blends generated pixels into the original.
                fake_img_masked = fake_mask * real_img + (1 -
                                                          fake_mask) * fake_img

                # Cycle translation back with the original AUs.
                g_input_cyc = preprocess_data(fake_img_masked, real_au)
                [cyc_img, cyc_mask] = G(g_input_cyc)
                cyc_img_masked = cyc_mask * fake_img_masked + (
                    1 - cyc_mask) * cyc_img

                # D(real_I)
                [pred_real_img, pred_real_au] = D(real_img)
                # D(fake_I)
                [pred_fake_img_masked, pred_fake_au] = D(fake_img_masked)
                pred_real_au = tf.squeeze(input=pred_real_au, axis=[1, 2])
                pred_fake_au = tf.squeeze(input=pred_fake_au, axis=[1, 2])

                # loss
                # WGAN critic loss: push real scores up, fake scores down.
                loss_d_img = -tf.reduce_mean(
                    pred_real_img) * lambda_D_img + tf.reduce_mean(
                        pred_fake_img_masked) * lambda_D_img
                loss_d_au = l2_loss(real_au, pred_real_au) * lambda_D_au

                # Gradient penalty on randomly interpolated real/fake images.
                with tf.GradientTape(persistent=True) as tape1:
                    alpha = tf.compat.v1.random_uniform([BATCH_SIZE, 1, 1, 1],
                                                        minval=0.,
                                                        maxval=1.)
                    differences = fake_img_masked - real_img
                    interpolates = real_img + tf.multiply(alpha, differences)
                    out = D(interpolates)
                gradients = tape1.gradient(out, [interpolates])
                del tape1
                # NOTE(review): `gradients` is a one-element list and the
                # reduction is over axis=1 only — verify this matches the
                # intended per-sample gradient norm.
                slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
                gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
                loss_d_gp = lambda_D_gp * gradient_penalty

                loss_d = loss_d_img + loss_d_au + loss_d_gp

            grad = tape.gradient(loss_d, D.trainable_weights)
            # Apply the epoch's scheduled learning rate before stepping.
            d_train_op.learning_rate = lr_fn(e)
            d_train_op.apply_gradients(zip(grad, D.trainable_weights))
            del tape

            d_summary_str = "Epoch: [{}/{}] [{}/{}] took: {:.3f}, d_loss: {:.5f}".format(
                e, EPOCHS, i, n_step_epoch,
                time.time() - start_time, loss_d)
            print(d_summary_str)

            # Update the generator only every 5th step (critic trains 5x more).
            if (i + 1) % 5 == 0:
                with tf.GradientTape(persistent=True) as tape:
                    real_img = face[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
                    real_au = au[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
                    desired_au = au_rand[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]

                    g_input_f = preprocess_data(real_img, desired_au)
                    [fake_img, fake_mask] = G(g_input_f)
                    fake_img_masked = fake_mask * real_img + (
                        1 - fake_mask) * fake_img
                    # G(G(Ic1, c2)*M, c1) * M
                    g_input_cyc = preprocess_data(fake_img_masked, real_au)
                    [cyc_img, cyc_mask] = G(g_input_cyc)
                    cyc_img_masked = cyc_mask * fake_img_masked + (
                        1 - cyc_mask) * cyc_img

                    # D(real_I)
                    [pred_real_img, pred_real_au] = D(real_img)
                    # D(fake_I)
                    [pred_fake_img_masked, pred_fake_au] = D(fake_img_masked)
                    pred_real_au = tf.squeeze(input=pred_real_au, axis=[1, 2])
                    pred_fake_au = tf.squeeze(input=pred_fake_au, axis=[1, 2])

                    # loss
                    # Adversarial: make D score masked fakes as real.
                    loss_g_fake_img_masked = -tf.reduce_mean(
                        pred_fake_img_masked) * lambda_D_img
                    loss_g_fake_au = l2_loss(desired_au,
                                             pred_fake_au) * lambda_D_au
                    loss_g_cyc = l1_loss(real_img, cyc_img_masked) * lambda_cyc

                    # Mask regularizers: sparsity (mean) + total-variation smoothness.
                    loss_g_mask_fake = tf.reduce_mean(
                        fake_mask) * lambda_mask + smooth_loss(
                            fake_mask) * lambda_mask_smooth
                    loss_g_mask_cyc = tf.reduce_mean(
                        cyc_mask) * lambda_mask + smooth_loss(
                            cyc_mask) * lambda_mask_smooth

                    loss_g = loss_g_fake_img_masked + loss_g_fake_au + \
                             loss_g_cyc + \
                             loss_g_mask_fake + loss_g_mask_cyc

                grad = tape.gradient(loss_g, G.trainable_weights)
                g_train_op.learning_rate = lr_fn(e)
                g_train_op.apply_gradients(zip(grad, G.trainable_weights))
                del tape

                g_summary_str = "Epoch: [{}/{}] [{}/{}] took: {:.3f}, g_loss: {:.5f}".format(
                    e, EPOCHS, i, n_step_epoch,
                    time.time() - start_time, loss_g)
                print(g_summary_str)

        print('(spend time: %.2fmin) loss_g: %.4f  loss_d: %.4f \n' %
              ((time.time() - start_time) / 60, loss_g, loss_d))

        if np.mod(e, 10) == 0:
            tl.files.save_npz(G.all_weights, name='G_' + str(e) + '.npz')
            tl.files.save_npz(D.all_weights, name='D_' + str(e) + '.npz')
Beispiel #13
0
def train():
    """Train a GAN-based super-resolution generator/discriminator pair.

    The objective is selected by ``args.method`` (M1..M7): combinations of
    VGG content loss, adversarial loss (vanilla or relativistic-average GAN),
    and a hypervolume loss over PSNR/SSIM terms.  Per batch: an optional
    pixel-loss warm-up phase, then a generator step followed by a
    discriminator step.  Per epoch: averaged losses are appended to a CSV
    log and, on configured intervals, plots, test metrics, and checkpoints
    are produced.

    Relies on module-level names: ``args``, ``device``, ``model``, ``data``,
    ``psnr_fn``/``ssim_fn``/``ssim_fn_vgg``, ``normalize_VGG_features``,
    ``write_to_csv_file``, ``plot_image``, ``test``.
    """
    print("GPU : ", torch.cuda.is_available())

    generator, optimizer_G, scheduler_G = model.get_generator(args)
    generator.to(device)

    discriminator, optimizer_D = model.get_discriminator(args)
    discriminator.to(device)

    if args.method == 'M3':
        # M3 uses a second discriminator (note: the factory is spelled
        # `get_discrminator2` in the model module — confirm not a typo there).
        discriminator2, optimizer_D2 = model.get_discrminator2(args)
        discriminator2.to(device)

    start_epoch = 0

    if args.resume_training:
        if args.debug:
            print("Resuming Training")
        checkpoint = torch.load(args.checkpoint_path)
        start_epoch = checkpoint['epoch']
        generator.load_state_dict(checkpoint['gen_state_dict'])
        optimizer_G.load_state_dict(checkpoint['gen_optimizer_dict'])
        scheduler_G.load_state_dict(checkpoint['gen_scheduler_dict'])
        discriminator.load_state_dict(checkpoint['dis_state_dict'])
        optimizer_D.load_state_dict(checkpoint['dis_optimizer_dict'])

    feature_extractor = VGGFeatureExtractor().to(device)

    # Set feature extractor to inference mode
    feature_extractor.eval()

    # Losses
    bce_loss = torch.nn.BCEWithLogitsLoss().to(device)
    l1_loss = torch.nn.L1Loss().to(device)
    l2_loss = torch.nn.MSELoss().to(device)

    # equal to negative of hypervolume of input losses
    hv_loss = HVLoss().to(device)

    # 1 - ssim(sr , hr)
    ssim_loss = SSIMLoss().to(device)

    dataloader = data.dataloader(args)
    test_dataloader = data.dataloader(args, train=False)

    batch_count = len(dataloader)

    generator.train()

    # Number of scalar values tracked per batch depends on the method.
    loss_len = 6
    if args.method == 'M4' or args.method == 'M7':
        loss_len = 7
    elif args.method == 'M6':
        loss_len = 9

    # Running history, one row per completed epoch: [epoch, losses...].
    # Starts as a 1-D zero row and becomes 2-D after the first epoch.
    losses_log = np.zeros(loss_len + 1)

    for epoch in range(start_epoch, start_epoch + args.epochs):
        # print("*"*15 , "Epoch :" , epoch , "*"*15)
        losses_gen = np.zeros(loss_len)
        for i, imgs in tqdm(enumerate(dataloader)):
            batches_done = epoch * len(dataloader) + i

            # Configure model input
            imgs_hr = imgs["hr"].to(device)
            imgs_lr = imgs["lr"].to(device)

            # ------------------
            #  Train Generators
            # ------------------

            # optimize generator
            # discriminator.eval()
            # Freeze D's weights so the generator step does not touch them.
            for p in discriminator.parameters():
                p.requires_grad = False

            if args.method == 'M3':
                for p in discriminator2.parameters():
                    p.requires_grad = False

            optimizer_G.zero_grad()

            gen_hr = generator(imgs_lr)

            # Scaling/Clipping output
            # gen_hr = gen_hr.clamp(0,1)

            if batches_done < args.warmup_batches:
                # Measure pixel-wise loss against ground truth
                if args.warmup_loss == "L1":
                    loss_pixel = l1_loss(gen_hr, imgs_hr)
                elif args.warmup_loss == "L2":
                    loss_pixel = l2_loss(gen_hr, imgs_hr)
                # Warm-up (pixel-wise loss only)
                loss_pixel.backward()
                optimizer_G.step()
                if args.debug:
                    print("[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]" %
                          (epoch, args.epochs, i, len(dataloader),
                           loss_pixel.item()))
                continue

            # Extract validity predictions from discriminator
            pred_real = discriminator(imgs_hr).detach()
            pred_fake = discriminator(gen_hr)

            # Adversarial ground truths
            valid = torch.ones_like(pred_real)
            fake = torch.zeros_like(pred_real)

            if args.gan == 'RAGAN':
                # Adversarial loss (relativistic average GAN)
                loss_GAN = bce_loss(
                    pred_fake - pred_real.mean(0, keepdim=True), valid)
            elif args.gan == "VGAN":
                # Adversarial loss (vanilla GAN)
                loss_GAN = bce_loss(pred_fake, valid)

            if args.method == 'M3':
                # Extract validity predictions from discriminator
                pred_real2 = discriminator2(imgs_hr).detach()
                pred_fake2 = discriminator2(gen_hr)

                valid2 = torch.ones_like(pred_real2)
                fake2 = torch.zeros_like(pred_real2)
                if args.gan == 'RAGAN':
                    # Adversarial loss (relativistic average GAN)
                    loss_GAN2 = bce_loss(
                        pred_fake2 - pred_real2.mean(0, keepdim=True), valid2)
                elif args.gan == "VGAN":
                    # Adversarial loss (vanilla GAN)
                    loss_GAN2 = bce_loss(pred_fake2, valid2)

            # Content loss
            gen_features = feature_extractor(gen_hr)
            real_features = feature_extractor(imgs_hr).detach()
            if args.vgg_criterion == 'L1':
                loss_content = l1_loss(gen_features, real_features)
            elif args.vgg_criterion == 'L2':
                loss_content = l2_loss(gen_features, real_features)

            # For vgg hv loss ?? max-value
            # max_value = (1.1 * torch.max(torch.max(gen_features) , torch.max(real_features))).detach()
            # print(max_value , end = "\n\n")
            # loss_vgg_hv_psnr = hv_loss(1 - (psnr_fn(gen_features , real_features , max_value=max_value)/30) , 1 - ssim_fn_val(gen_features , real_features , max_value))

            psnr_val = psnr_fn(gen_hr, imgs_hr.detach())
            ssim_val = ssim_fn(gen_hr, imgs_hr.detach())

            # Total generator loss
            if args.method == 'M4':
                loss_hv_psnr = hv_loss(1 - (psnr_val / args.max_psnr),
                                       1 - ssim_val)
                loss_G = (loss_content * args.weight_vgg) + (
                    loss_hv_psnr * args.weight_hv) + (args.weight_gan *
                                                      loss_GAN)
            elif args.method == 'M1':
                loss_G = (loss_content * args.weight_vgg) + (args.weight_gan *
                                                             loss_GAN)
            elif args.method == 'M5':
                psnr_loss = (1 - (psnr_val / args.max_psnr)).mean()
                ssim_loss = (1 - ssim_val).mean()
                loss_G = (loss_content * args.weight_vgg) + (
                    args.weight_gan * loss_GAN) + (args.weight_pslinear *
                                                   (ssim_loss + psnr_loss))
            elif args.method == 'M6':
                real_features_normalized = normalize_VGG_features(
                    real_features)
                gen_features_normalized = normalize_VGG_features(gen_features)
                psnr_vgg_val = psnr_fn(gen_features_normalized,
                                       real_features_normalized)
                ssim_vgg_val = ssim_fn_vgg(gen_features_normalized,
                                           real_features_normalized)
                loss_vgg_hv = hv_loss(1 - (psnr_vgg_val / args.max_psnr),
                                      1 - ssim_vgg_val)
                loss_G = (args.weight_vgg_hv *
                          loss_vgg_hv) + (args.weight_gan * loss_GAN)
            elif args.method == 'M7':
                loss_hv_psnr = hv_loss(1 - (psnr_val / args.max_psnr),
                                       1 - ssim_val)
                if (epoch - start_epoch) < args.loss_mem:
                    loss_G = (loss_content * args.weight_vgg) + (
                        loss_hv_psnr * args.weight_hv) + (args.weight_gan *
                                                          loss_GAN)
                else:
                    # M7 after warm-in: each term is re-weighted by the inverse
                    # of its mean over the last `loss_mem` epochs.  Column
                    # layout of losses_log rows (M4/M7): 0=epoch, 1=content,
                    # 2=GAN, 3=hv, 4=loss_G, 5=loss_D, 6=psnr, 7=ssim.
                    weight_vgg = (1 / losses_log[-args.loss_mem:, 1].mean()
                                  ) * args.mem_vgg_weight
                    weight_bce = (1 / losses_log[-args.loss_mem:, 2].mean()
                                  ) * args.mem_bce_weight
                    weight_hv = (1 / losses_log[-args.loss_mem:,
                                                3].mean()) * args.mem_hv_weight
                    loss_G = (loss_content * weight_vgg) + (
                        loss_hv_psnr * weight_hv) + (loss_GAN * weight_bce)
            elif args.method == "M2":
                loss_G = hv_loss(loss_GAN * args.weight_gan,
                                 loss_content * args.weight_vgg)
            elif args.method == 'M3':
                loss_G = (args.weight_vgg * loss_content) + (
                    args.weight_hv * hv_loss(loss_GAN * args.weight_gan,
                                             loss_GAN2 * args.weight_gan))

            if args.include_l1:
                loss_G += (args.weight_l1 * l1_loss(gen_hr, imgs_hr))

            loss_G.backward()
            optimizer_G.step()
            scheduler_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # optimize discriminator
            # discriminator.train()
            # Unfreeze D's weights for its own update.
            for p in discriminator.parameters():
                p.requires_grad = True

            if args.method == 'M3':
                for p in discriminator2.parameters():
                    p.requires_grad = True

            if args.method == 'M3':
                pred_real2 = discriminator2(imgs_hr)
                pred_fake2 = discriminator2(gen_hr.detach())
                valid2 = torch.ones_like(pred_real2)
                fake2 = torch.zeros_like(pred_real2)
                if args.gan == "RAGAN":
                    # Adversarial loss for real and fake images (relativistic average GAN)
                    loss_real2 = bce_loss(
                        pred_real2 - pred_fake2.mean(0, keepdim=True), valid2)
                    loss_fake2 = bce_loss(
                        pred_fake2 - pred_real2.mean(0, keepdim=True), fake2)
                elif args.gan == "VGAN":
                    # Adversarial loss for real and fake images (vanilla GAN)
                    loss_real2 = bce_loss(pred_real2, valid2)
                    loss_fake2 = bce_loss(pred_fake2, fake2)

                optimizer_D2.zero_grad()
                loss_D2 = (loss_real2 + loss_fake2) / 2
                loss_D2.backward()
                optimizer_D2.step()

            optimizer_D.zero_grad()

            pred_real = discriminator(imgs_hr)
            pred_fake = discriminator(gen_hr.detach())

            if args.gan == "RAGAN":
                # Adversarial loss for real and fake images (relativistic average GAN)
                loss_real = bce_loss(
                    pred_real - pred_fake.mean(0, keepdim=True), valid)
                loss_fake = bce_loss(
                    pred_fake - pred_real.mean(0, keepdim=True), fake)
            elif args.gan == "VGAN":
                # Adversarial loss for real and fake images (vanilla GAN)
                loss_real = bce_loss(pred_real, valid)
                loss_fake = bce_loss(pred_fake, fake)

            # Total loss
            loss_D = (loss_real + loss_fake) / 2

            if args.method == 'M7':
                if (epoch - start_epoch) >= args.loss_mem:
                    # Scale D's loss by the recent mean of column 5 (loss_D).
                    weight_dis = (1 / losses_log[-args.loss_mem:, 5].mean()
                                  ) * args.mem_bce_weight
                    loss_D = loss_D / weight_dis

            loss_D.backward()
            optimizer_D.step()

            # Accumulate per-batch scalars; order defines the CSV/log columns.
            if args.method == "M4" or args.method == "M7":
                losses_gen += np.array([
                    loss_content.item(),
                    loss_GAN.item(),
                    loss_hv_psnr.item(),
                    loss_G.item(),
                    loss_D.item(),
                    psnr_val.mean().item(),
                    ssim_val.mean().item(),
                ])
            elif args.method == "M6":
                losses_gen += np.array([
                    loss_content.item(),
                    loss_GAN.item(),
                    psnr_vgg_val.mean().item(),
                    ssim_vgg_val.mean().item(),
                    loss_vgg_hv.item(),
                    loss_G.item(),
                    loss_D.item(),
                    psnr_val.mean().item(),
                    ssim_val.mean().item(),
                ])
            else:
                losses_gen += np.array([
                    loss_content.item(),
                    loss_GAN.item(),
                    loss_G.item(),
                    loss_D.item(),
                    psnr_val.mean().item(),
                    ssim_val.mean().item(),
                ])

        # Epoch averages, with the epoch index prepended as column 0.
        losses_gen /= batch_count
        losses_gen = list(losses_gen)
        losses_gen.insert(0, epoch)

        write_to_csv_file(os.path.join(args.output_path, 'train_log.csv'),
                          losses_gen)

        # First completed epoch replaces the all-zero seed row; later epochs
        # are stacked below it.
        if (losses_log == np.zeros(loss_len + 1)).sum() == loss_len + 1:
            losses_log = np.expand_dims(np.array(losses_gen), 0)
        else:
            losses_log = np.vstack((losses_log, losses_gen))

        if epoch % args.print_every == 0:
            print('Epoch', epoch, 'Loss GAN :', losses_gen)

        if epoch % args.plot_every == 0:
            plot_image(epoch, generator, test_dataloader)

        if epoch % args.test_every == 0:
            test(epoch, generator, test_dataloader)

        if epoch % args.save_model_every == 0:
            # 'epoch' is stored as epoch + 1 so resuming starts at the next one.
            checkpoint = {
                'epoch': epoch + 1,
                'gen_state_dict': generator.state_dict(),
                'gen_optimizer_dict': optimizer_G.state_dict(),
                'gen_scheduler_dict': scheduler_G.state_dict(),
                'dis_state_dict': discriminator.state_dict(),
                'dis_optimizer_dict': optimizer_D.state_dict(),
            }
            os.makedirs(os.path.join(args.output_path, 'saved_model'),
                        exist_ok=True)
            torch.save(
                checkpoint,
                os.path.join(args.output_path, 'saved_model',
                             'checkpoint_' + str(epoch) + ".pth"))
Beispiel #14
0
from tqdm import tqdm


def preprocess_data(img_input, au_input):
    """Broadcast the AU vector over the spatial grid and append it to the
    image channels: ([N,128,128,3] image, [N,17] AUs) -> [N,128,128,20]."""
    # [None, 17] -> [None, 1, 1, 17]
    au_map = tf.expand_dims(au_input, axis=1, name='expand_dims1')
    au_map = tf.expand_dims(au_map, axis=2, name='expand_dims2')
    # Replicate across the 128x128 spatial grid -> [None, 128, 128, 17]
    au_map = tf.tile(au_map, multiples=[1, 128, 128, 1], name='tile')
    # Stack behind the image channels -> [None, 128, 128, 20]
    return tf.concat([img_input, au_map], axis=3, name='concat')


if __name__ == '__main__':
    # Build the GANimation-style networks: G takes image+AU channels (20),
    # D takes a plain RGB image.
    G = get_generator([None, 128, 128, 20])
    D = get_discriminator([None, 128, 128, 3])

    # Restore pretrained weights (epoch 30 snapshots).
    G_path = 'G_30.npz'
    tl.files.load_and_assign_npz(name=G_path, network=G)
    D_path = 'D_30.npz'
    tl.files.load_and_assign_npz(name=D_path, network=D)

    imgs_names = os.listdir('test_face')
    real_src = face_recognition.load_image_file('test.jpeg')  # RGB image
    face_loc = face_recognition.face_locations(real_src)

    # Use the first detected face.  (The original code unpacked face_loc[0]
    # unconditionally and then repeated the identical unpack inside a
    # redundant `if len(face_loc) == 1` guard; the dead duplicate is removed.
    # An IndexError is still raised when no face is found, as before.)
    top, right, bottom, left = face_loc[0]

    real_face = np.zeros((1, 128, 128, 3), dtype=np.float32)
Beispiel #15
0
import pickle

def maskBig(x, target, threshold):
    """Zero out entries of x whose signed score x*(target-0.5) exceeds
    threshold.  Mutates x in place and returns the same array."""
    score = x * (target - 0.5)
    x[score > threshold] = 0.0
    return x

if __name__ == '__main__':
    #parser.add_argument('-save', type=str, default = './checkpoint/test/', help='place to save')
    # Optional path prefix (e.g. a mounted Drive folder); empty = current dir.
    _path = ''#'/content/drive/My Drive/Colab Notebooks/myblast/'

    # Hyper-parameters are read from an INI file next to the script.
    config = configparser.ConfigParser()
    config.read(_path+'mixed_15720.ini')
    #gpu_tracker.track()
    encoder = model.get_encoder(config, "M")
    discriminator = model.get_discriminator(config)
    generator = model.get_generator(config)
    if torch.cuda.is_available():
        encoder = encoder.cuda()
        discriminator = discriminator.cuda()
        generator = generator.cuda()
    #classifier = model.get_classifier(config).cuda()
    #gpu_tracker.track()
    #optimC = optim.Adam(classifier.parameters(), lr=config.getfloat('training', 'lr'))
    # NOTE(review): the encoder trains with a 100x smaller learning rate than
    # the generator/discriminator — confirm this is intentional.
    optimE = optim.Adam(encoder.parameters(), lr=config.getfloat('training', 'lr')*0.01) 
    optimG = optim.Adam(generator.parameters(), lr=config.getfloat('training', 'lr'))
    optimD = optim.Adam(discriminator.parameters(), lr=config.getfloat('training', 'lr'))
    '''
    Quake_Smart_seq2 = data.read_dataset(_path+"../data/Quake_Smart-seq2/data.h5")
    Quake_10x = data.read_dataset(_path+"../data/Quake_10x/data.h5")
Beispiel #16
0
def train():
    """Train a DCGAN on CelebA where the generator's gradients are derived
    from the discriminator's fake loss and flipped in sign, instead of
    backpropagating a separate generator loss.

    Fix: gradient-presence checks now use ``is not None`` / ``is None``
    identity tests.  The original compared tensors to None with ``!=`` /
    ``==``, which relies on overloadable ``__eq__``/``__ne__`` and is
    unreliable for tensor objects; identity comparison is the correct way
    to detect missing (None) gradients such as batch-norm moving stats.
    """
    images, images_path = get_celebA(FLAGS.output_size, FLAGS.n_epoch,
                                     FLAGS.batch_size)
    G = get_generator([None, FLAGS.z_dim])
    D = get_discriminator(
        [None, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim])

    G.train()
    D.train()

    d_optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate,
                                         beta1=FLAGS.beta1)
    g_optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate,
                                         beta1=FLAGS.beta1)

    n_step_epoch = int(len(images_path) // FLAGS.batch_size)

    for step, batch_images in enumerate(images):
        step_time = time.time()

        with tf.GradientTape(persistent=True) as tape:
            z = tf.contrib.distributions.Normal(0., 1.).sample([
                FLAGS.batch_size, FLAGS.z_dim
            ])  #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
            d_logits = D(G(z))
            d2_logits = D(batch_images)
            d_loss_real = tl.cost.sigmoid_cross_entropy(
                d2_logits, tf.ones_like(d2_logits), name='real')
            d_loss_fake = tl.cost.sigmoid_cross_entropy(
                d_logits, tf.zeros_like(d_logits), name='fake')

        # Joint gradient of the fake loss w.r.t. both networks; the G slice is
        # reused (sign-flipped via `scale`) as the generator update.
        grad_gd = tape.gradient(d_loss_fake, G.weights + D.weights)
        grad_d1 = tape.gradient(d_loss_real, D.weights)
        scale = -1  #tf.reduce_mean(sigmoid(d_logits)/(sigmoid(d_logits)-1))
        grad_g = grad_gd[0:len(G.weights)]
        for i in range(len(grad_g)):
            if grad_g[i] is not None:  # batch_norm moving mean, var
                grad_g[i] = grad_g[i] * scale
            # grad_d1 = list(filter(lambda x: correct_grad(x, scale), grad_d1))
        grad_d2 = grad_gd[len(G.weights):]
        # Sum the real- and fake-loss gradients for D, preserving None entries.
        grad_d = []
        for x, y in zip(grad_d1, grad_d2):
            if x is None:  # batch_norm moving mean, var
                grad_d.append(None)
            else:
                grad_d.append(x + y)
        g_optimizer.apply_gradients(zip(grad_g, G.weights))
        d_optimizer.apply_gradients(zip(grad_d, D.weights))
        del tape  # persistent tape must be released explicitly

        g_loss = d_loss_fake
        d_loss = d_loss_real + d_loss_fake

        print(
            "Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".
            format(step // n_step_epoch, FLAGS.n_epoch, step, n_step_epoch,
                   time.time() - step_time, d_loss, g_loss))
        if np.mod(step, FLAGS.save_step) == 0:
            G.save_weights('{}/G.npz'.format(FLAGS.checkpoint_dir),
                           format='npz')
            D.save_weights('{}/D.npz'.format(FLAGS.checkpoint_dir),
                           format='npz')
            result = G(z)
            tl.visualize.save_images(
                result.numpy(), [num_tiles, num_tiles],
                '{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir,
                                                    step // n_step_epoch,
                                                    step))
Beispiel #17
0
    test_A.map(preprocess_test_image, num_parallel_calls=autotune).cache()
    #.shuffle(buffer_size)
    .batch(batch_size))
# test_B = (
#     test_B.map(preprocess_test_image, num_parallel_calls=autotune)
#     .cache()
#     #.shuffle(buffer_size)
#     .batch(batch_size)
# )

# Get the generators
# NOTE(review): presumably gen_G maps domain X -> Y and gen_F maps Y -> X,
# per the usual CycleGAN naming — confirm against get_resnet_generator usage.
gen_G = get_resnet_generator(name="generator_G")
gen_F = get_resnet_generator(name="generator_F")

# Get the discriminators
disc_X = get_discriminator(name="discriminator_X")
disc_Y = get_discriminator(name="discriminator_Y")


class CycleGan(keras.Model):
    """Composite CycleGAN model holding two generators and two discriminators."""

    def __init__(self, generator_G, generator_F, discriminator_X,
                 discriminator_Y):
        """Store the four sub-networks; loss weights are currently disabled
        (see the commented-out lambda_* attributes below)."""
        super(CycleGan, self).__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        # self.lambda_cycle = lambda_cycle
        # self.lambda_identity = lambda_identity
        # self.lambda_gamma = lambda_gamma
Beispiel #18
0
def train():
    """Run eager-mode DCGAN training on CelebA with TensorLayer models,
    printing per-step losses and periodically saving weights and samples."""
    images, images_path = get_celebA(flags.output_size, flags.n_epoch,
                                     flags.batch_size)
    G = get_generator([None, flags.z_dim])
    D = get_discriminator(
        [None, flags.output_size, flags.output_size, flags.c_dim])

    G.train()
    D.train()

    d_optimizer = tf.optimizers.Adam(flags.learning_rate, beta_1=flags.beta1)
    g_optimizer = tf.optimizers.Adam(flags.learning_rate, beta_1=flags.beta1)

    n_step_epoch = int(len(images_path) // flags.batch_size)

    for step, batch_images in enumerate(images):
        tic = time.time()
        with tf.GradientTape(persistent=True) as tape:
            # Sample latent vectors from a standard normal distribution.
            noise = np.random.normal(0.0, 1.0,
                                     size=(flags.batch_size,
                                           flags.z_dim)).astype(np.float32)
            fake_logits = D(G(noise))
            real_logits = D(batch_images)
            # Discriminator targets: real -> 1, generated -> 0.
            d_loss_real = tl.cost.sigmoid_cross_entropy(
                real_logits, tf.ones_like(real_logits), name='dreal')
            d_loss_fake = tl.cost.sigmoid_cross_entropy(
                fake_logits, tf.zeros_like(fake_logits), name='dfake')
            d_loss = d_loss_real + d_loss_fake
            # Generator target: push D(G(z)) toward 1.
            g_loss = tl.cost.sigmoid_cross_entropy(fake_logits,
                                                   tf.ones_like(fake_logits),
                                                   name='gfake')

        g_grads = tape.gradient(g_loss, G.weights)
        g_optimizer.apply_gradients(zip(g_grads, G.weights))
        d_grads = tape.gradient(d_loss, D.weights)
        d_optimizer.apply_gradients(zip(d_grads, D.weights))
        del tape  # persistent tape must be released explicitly

        print(
            "Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".
            format(step // n_step_epoch, flags.n_epoch, step, n_step_epoch,
                   time.time() - tic, d_loss, g_loss))
        if step % flags.save_step == 0:
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir),
                           format='npz')
            D.save_weights('{}/D.npz'.format(flags.checkpoint_dir),
                           format='npz')
            # Sample in eval mode so batch-norm uses its moving statistics.
            G.eval()
            result = G(noise)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [num_tiles, num_tiles],
                '{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir,
                                                    step // n_step_epoch,
                                                    step))
Beispiel #19
0
def train():
    """DCGAN training on CelebA (TF1 contrib APIs) that additionally reports
    a Frechet classifier distance (FID-style, using D as the classifier)
    once per epoch and again after the final step."""
    images, images_path = get_celebA(FLAGS.output_size, FLAGS.n_epoch,
                                     FLAGS.batch_size)
    G = get_generator([None, FLAGS.z_dim])
    D = get_discriminator(
        [None, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim])

    G.train()
    D.train()

    d_optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate,
                                         beta1=FLAGS.beta1)
    g_optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate,
                                         beta1=FLAGS.beta1)

    n_step_epoch = int(len(images_path) // FLAGS.batch_size)
    # Timer is reset once per epoch (below), so "took" covers a whole epoch.
    step_time = time.time()
    for step, batch_images in enumerate(images):
        #step_time = time.time()
        with tf.GradientTape(persistent=True) as tape:
            z = tf.contrib.distributions.Normal(0., 1.).sample([
                FLAGS.batch_size, FLAGS.z_dim
            ])  #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
            d_logits = D(G(z))
            d2_logits = D(batch_images)
            # discriminator: real images are labelled as 1
            d_loss_real = tl.cost.sigmoid_cross_entropy(
                d2_logits, tf.ones_like(d2_logits), name='dreal')
            # discriminator: images from generator (fake) are labelled as 0
            d_loss_fake = tl.cost.sigmoid_cross_entropy(
                d_logits, tf.zeros_like(d_logits), name='dfake')
            # combined loss for updating discriminator
            d_loss = d_loss_real + d_loss_fake
            # generator: try to fool discriminator to output 1
            g_loss = tl.cost.sigmoid_cross_entropy(d_logits,
                                                   tf.ones_like(d_logits),
                                                   name='gfake')

        grad = tape.gradient(g_loss, G.weights)
        g_optimizer.apply_gradients(zip(grad, G.weights))
        grad = tape.gradient(d_loss, D.weights)
        d_optimizer.apply_gradients(zip(grad, D.weights))
        del tape  # persistent tape must be released explicitly

        #print("Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".format(step//n_step_epoch, FLAGS.n_epoch, step, n_step_epoch, time.time()-step_time, d_loss, g_loss))
        if np.mod(step, n_step_epoch) == 0:
            # Epoch boundary: compute FID between the real batch and samples.
            fid = tf.contrib.gan.eval.frechet_classifier_distance(
                batch_images, G(z), D, num_batches=8)
            print(
                "Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}, fid: {:5f}"
                .format(step // n_step_epoch, FLAGS.n_epoch, step,
                        n_step_epoch,
                        time.time() - step_time, d_loss, g_loss, fid))
            step_time = time.time()

        if np.mod(step, FLAGS.save_step) == 0:
            G.save_weights('{}/G.npz'.format(FLAGS.checkpoint_dir),
                           format='npz')
            D.save_weights('{}/D.npz'.format(FLAGS.checkpoint_dir),
                           format='npz')
            result = G(z)
            tl.visualize.save_images(
                result.numpy(), [num_tiles, num_tiles],
                '{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir,
                                                    step // n_step_epoch,
                                                    step))

    # Final report after the loop: one more FID, a tiled sample sheet, and
    # each sample image saved individually.
    fid = tf.contrib.gan.eval.frechet_classifier_distance(batch_images,
                                                          G(z),
                                                          D,
                                                          num_batches=8)
    print(
        "Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}, fid: {:5f}"
        .format(step // n_step_epoch, FLAGS.n_epoch, step, n_step_epoch,
                time.time() - step_time, d_loss, g_loss, fid))
    result = G(z).numpy()
    tl.visualize.save_images(
        result, [num_tiles, num_tiles],
        '{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir,
                                            step // n_step_epoch, step))
    for i in range(result.shape[0]):
        tl.visualize.save_image(
            result[i, :, :, :],
            '{}/train_{:02d}.png'.format(FLAGS.sample_dir, i))