def build_wgan_gp(self):
    """Build the WGAN-GP training graphs for both networks.

    Creates ``self.gen_trainer`` (generator updates, discriminator frozen)
    and ``self.disc_trainer`` (discriminator updates, generator frozen,
    with a gradient-penalty output), then compiles both via ``self.compile()``.
    """
    # Query input shapes from the existing gen/disc models.
    cond_shapes = input_shapes(self.gen, "cond")
    noise_shapes = input_shapes(self.gen, "noise")
    sample_shapes = input_shapes(self.disc, "sample")

    # Create generator training network (discriminator weights frozen).
    with Nontrainable(self.disc):
        cond_in = [Input(shape=s) for s in cond_shapes]
        noise_in = [Input(shape=s) for s in noise_shapes]
        gen_in = cond_in + noise_in
        gen_out = self.gen(gen_in)
        gen_out = ensure_list(gen_out)
        # BUG FIX: gen_out is already a list after ensure_list(); wrapping it
        # again as [gen_out] would hand the discriminator a nested list.
        # Flat concatenation matches the disc calls in the branch below.
        disc_in_gen = cond_in + gen_out
        disc_out_gen = self.disc(disc_in_gen)
        self.gen_trainer = Model(inputs=gen_in, outputs=disc_out_gen)

    # Create discriminator training network (generator weights frozen).
    with Nontrainable(self.gen):
        cond_in = [Input(shape=s) for s in cond_shapes]
        noise_in = [Input(shape=s) for s in noise_shapes]
        sample_in = [Input(shape=s) for s in sample_shapes]
        gen_in = cond_in + noise_in
        disc_in_real = sample_in[0]
        disc_in_fake = self.gen(gen_in)
        # Random interpolation between real and fake samples, needed for
        # the WGAN-GP gradient penalty term.
        disc_in_avg = RandomWeightedAverage()([disc_in_real, disc_in_fake])
        disc_out_real = self.disc(cond_in + [disc_in_real])
        disc_out_fake = self.disc(cond_in + [disc_in_fake])
        disc_out_avg = self.disc(cond_in + [disc_in_avg])
        disc_gp = GradientPenalty()([disc_out_avg, disc_in_avg])
        self.disc_trainer = Model(
            inputs=cond_in + sample_in + noise_in,
            outputs=[disc_out_real, disc_out_fake, disc_gp]
        )

    self.compile()
def load(self, load_files):
    """Restore network weights and optimizer state from saved files.

    ``load_files`` maps the keys ``gen_weights``, ``disc_weights``,
    ``gen_opt_weights`` and ``disc_opt_weights`` to file paths.
    """
    self.gen.load_weights(load_files["gen_weights"])
    self.disc.load_weights(load_files["disc_weights"])

    # Optimizer state can only be restored once the train function exists,
    # hence the explicit _make_train_function() calls before loading.
    restore_plan = [
        (self.disc, self.gen_trainer, "gen_opt_weights"),
        (self.gen, self.disc_trainer, "disc_opt_weights"),
    ]
    for frozen, trainer, key in restore_plan:
        with Nontrainable(frozen):
            trainer._make_train_function()
            load_opt_weights(trainer, load_files[key])
def compile(self, opt_disc=None, opt_gen=None):
    """Compile the generator and discriminator trainers.

    If optimizers are not supplied, Adam instances are created with the
    WGAN-GP-style momentum settings (beta_1=0.5, beta_2=0.9).
    """
    # Fall back to default optimizers when none were provided.
    if opt_disc is None:
        opt_disc = Adam(self.lr_disc, beta_1=0.5, beta_2=0.9)
    if opt_gen is None:
        opt_gen = Adam(self.lr_gen, beta_1=0.5, beta_2=0.9)
    self.opt_disc = opt_disc
    self.opt_gen = opt_gen

    # Generator trainer: discriminator frozen, single Wasserstein loss.
    with Nontrainable(self.disc):
        self.gen_trainer.compile(
            loss=wasserstein_loss,
            optimizer=self.opt_gen
        )

    # Discriminator trainer: generator frozen; real/fake Wasserstein terms
    # plus the MSE-style gradient-penalty output.
    with Nontrainable(self.gen):
        self.disc_trainer.compile(
            loss=[wasserstein_loss, wasserstein_loss, 'mse'],
            loss_weights=[1.0, 1.0, self.gradient_penalty_weight],
            optimizer=self.opt_disc
        )
def fit_generator(self, batch_gen, noise_gen, steps_per_epoch=1,
                  num_epochs=1, training_ratio=1):
    """Train with alternating discriminator/generator updates.

    Per generator step the discriminator is updated ``training_ratio``
    times. Progress is reported with a Keras progress bar.
    """
    # Target arrays match the discriminator's per-sample output shape:
    # +1 for real, -1 for fake, 0 for the gradient-penalty output.
    out_shape = (batch_gen.batch_size, self.disc.output_shape[1])
    real_target = np.ones(out_shape, dtype=np.float32)
    fake_target = -real_target
    gp_target = np.zeros_like(real_target)

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        # Progress bar counts samples, not steps.
        progbar = generic_utils.Progbar(
            steps_per_epoch * batch_gen.batch_size)

        for _ in range(steps_per_epoch):
            # Discriminator phase (generator frozen).
            with Nontrainable(self.gen):
                for _ in range(training_ratio):
                    image_batch = next(batch_gen)
                    noise_batch = next(noise_gen)
                    disc_loss = self.disc_trainer.train_on_batch(
                        [image_batch] + noise_batch,
                        [real_target, fake_target, gp_target])

            # Generator phase (discriminator frozen).
            with Nontrainable(self.disc):
                noise_batch = next(noise_gen)
                gen_loss = self.gen_trainer.train_on_batch(
                    noise_batch, real_target)

            losses = [("D{}".format(i), dl)
                      for (i, dl) in enumerate(disc_loss)]
            losses.append(("G0", gen_loss))
            progbar.add(batch_gen.batch_size, values=losses)
def build(self):
    """Build and compile the conditional-GAN training networks.

    Creates ``self.disc_trainer`` (real/fake binary-crossentropy heads,
    generator frozen) and ``self.gen_trainer`` (discriminator frozen).
    """
    img_shape = input_shapes(self.disc, "sample_in")[0]
    cond_shapes = input_shapes(self.gen, "cond_in")
    noise_shapes = input_shapes(self.gen, "noise_in")

    # Create optimizers (beta_1=0.5 is the common GAN-stability setting).
    self.opt_disc = Adam(self.lr_disc, beta_1=0.5, beta_2=0.9)
    self.opt_gen = Adam(self.lr_gen, beta_1=0.5, beta_2=0.9)

    # Build discriminator training network (generator frozen).
    with Nontrainable(self.gen):
        real_image = Input(shape=img_shape)
        cond = [Input(shape=s) for s in cond_shapes]
        noise = [Input(shape=s) for s in noise_shapes]
        disc_real = self.disc([real_image] + cond)
        # BUG FIX: cond and noise are already lists of tensors, so
        # [cond]+[noise] would nest them into a list of lists. Flat
        # concatenation matches the generator-training branch below.
        generated_image = self.gen(cond + noise)
        disc_fake = self.disc([generated_image] + cond)
        self.disc_trainer = Model(
            inputs=[real_image] + cond + noise,
            outputs=[disc_real, disc_fake]
        )
        self.disc_trainer.compile(
            optimizer=self.opt_disc,
            loss=["binary_crossentropy", "binary_crossentropy"])

    # Build generator training network (discriminator frozen).
    with Nontrainable(self.disc):
        cond = [Input(shape=s) for s in cond_shapes]
        noise = [Input(shape=s) for s in noise_shapes]
        generated_image = self.gen(cond + noise)
        disc_fake = self.disc([generated_image] + cond)
        self.gen_trainer = Model(
            inputs=cond + noise,
            outputs=disc_fake
        )
        self.gen_trainer.compile(optimizer=self.opt_gen,
            loss="binary_crossentropy")
def build(self):
    """Build and compile the unconditional WGAN-GP training networks.

    Creates ``self.disc_trainer`` (real/fake/gradient-penalty outputs,
    generator frozen) and ``self.gen_trainer`` (discriminator frozen).
    """
    img_shape = input_shapes(self.disc, "sample_in")[0]
    noise_shapes = input_shapes(self.gen, "noise_in")

    # Create optimizers with WGAN-GP momentum settings.
    self.opt_disc = Adam(self.lr_disc, beta_1=0.5, beta_2=0.9)
    self.opt_gen = Adam(self.lr_gen, beta_1=0.5, beta_2=0.9)

    # Build discriminator training network (generator frozen).
    with Nontrainable(self.gen):
        real_image = Input(shape=img_shape)
        noise = [Input(shape=s) for s in noise_shapes]
        disc_real = self.disc(real_image)
        generated_image = self.gen(noise)
        disc_fake = self.disc(generated_image)
        # Random real/fake interpolate for the gradient penalty term.
        avg_image = RandomWeightedAverage()([real_image, generated_image])
        disc_avg = self.disc(avg_image)
        gp = GradientPenalty()([disc_avg, avg_image])
        # BUG FIX: noise is a list of Input tensors; [real_image, noise]
        # would embed that list inside the inputs list. A flat list is
        # required (and is what train() feeds: [image_batch]+noise_batch).
        self.disc_trainer = Model(
            inputs=[real_image] + noise,
            outputs=[disc_real, disc_fake, gp])
        self.disc_trainer.compile(
            optimizer=self.opt_disc,
            loss=[wasserstein_loss, wasserstein_loss, 'mse'],
            loss_weights=[1.0, 1.0, 10.0])

    # Build generator training network (discriminator frozen).
    with Nontrainable(self.disc):
        noise = [Input(shape=s) for s in noise_shapes]
        generated_image = self.gen(noise)
        disc_fake = self.disc(generated_image)
        self.gen_trainer = Model(inputs=noise, outputs=disc_fake)
        self.gen_trainer.compile(optimizer=self.opt_gen,
            loss=wasserstein_loss)
def train(self, batch_gen, noise_gen, num_gen_batches=1,
          training_ratio=1, show_progress=True):
    """Run ``num_gen_batches`` generator updates with interleaved
    discriminator updates.

    Per generator batch the discriminator is trained ``training_ratio``
    times and its losses averaged. Returns an array of per-batch losses
    (discriminator losses followed by the generator loss).
    """
    if show_progress:
        # Progress bar counts samples, not batches.
        progbar = generic_utils.Progbar(
            num_gen_batches * batch_gen.batch_size)

    # BUG FIX: target construction was entangled with the show_progress
    # branch (disc_target_real started as None) — with show_progress=False
    # the targets would never be built. Build them unconditionally.
    # Targets: +1 real, -1 fake, 0 for the gradient-penalty output.
    disc_target_real = np.ones(
        (batch_gen.batch_size, batch_gen.num_frames, 1), dtype=np.float32)
    disc_target_fake = -disc_target_real
    gen_target = disc_target_real
    target_gp = np.zeros((batch_gen.batch_size, 1), dtype=np.float32)
    disc_target = [disc_target_real, disc_target_fake, target_gp]

    loss_log = []
    for k in range(num_gen_batches):
        # Train discriminator, averaging losses over the repeats.
        disc_loss = None
        disc_loss_n = 0
        for rep in range(training_ratio):
            # Generate some real samples.
            (sample, cond) = next(batch_gen)
            noise = noise_gen()
            with Nontrainable(self.gen):
                dl = self.disc_trainer.train_on_batch(
                    [cond, sample] + noise, disc_target)
            if disc_loss is None:
                disc_loss = np.array(dl)
            else:
                disc_loss += np.array(dl)
            disc_loss_n += 1
            del sample, cond
        disc_loss /= disc_loss_n

        # Train generator (discriminator frozen).
        with Nontrainable(self.disc):
            (sample, cond) = next(batch_gen)
            gen_loss = self.gen_trainer.train_on_batch(
                [cond] + noise_gen(), gen_target)
            del sample, cond

        if show_progress:
            losses = []
            for (i, dl) in enumerate(disc_loss):
                losses.append(("D{}".format(i), dl))
            for (i, gl) in enumerate([gen_loss]):
                losses.append(("G{}".format(i), gl))
            progbar.add(batch_gen.batch_size, values=losses)

        loss_log.append(np.hstack((disc_loss, gen_loss)))
        # Explicit collection keeps memory bounded over long runs.
        gc.collect()

    return np.array(loss_log)