Example #1
0
    def build_wgan_gp(self):
        """Build the WGAN-GP training graphs.

        Creates two Keras models and stores them on the instance:
        - ``self.gen_trainer``: generator update path (discriminator frozen),
          mapping (cond, noise) inputs to the discriminator score of the
          generated sample.
        - ``self.disc_trainer``: discriminator update path (generator frozen),
          outputting [real score, fake score, gradient penalty].

        Finally calls ``self.compile()`` to attach losses/optimizers.
        """
        # Find input shapes by inspecting the named inputs of the sub-models.
        cond_shapes = input_shapes(self.gen, "cond")
        noise_shapes = input_shapes(self.gen, "noise")
        sample_shapes = input_shapes(self.disc, "sample")

        # Create generator training network (discriminator weights frozen).
        with Nontrainable(self.disc):
            cond_in = [Input(shape=s) for s in cond_shapes]
            noise_in = [Input(shape=s) for s in noise_shapes]
            gen_in = cond_in + noise_in
            gen_out = self.gen(gen_in)
            gen_out = ensure_list(gen_out)
            # BUG FIX: gen_out is already a list after ensure_list(), so the
            # original `cond_in + [gen_out]` produced a nested list; the
            # discriminator expects a flat list of tensors (cf. the single-
            # tensor wrapping used below for disc_in_real/fake/avg).
            disc_in_gen = cond_in + gen_out
            disc_out_gen = self.disc(disc_in_gen)
            self.gen_trainer = Model(inputs=gen_in, outputs=disc_out_gen)

        # Create discriminator training network (generator weights frozen).
        with Nontrainable(self.gen):
            cond_in = [Input(shape=s) for s in cond_shapes]
            noise_in = [Input(shape=s) for s in noise_shapes]
            sample_in = [Input(shape=s) for s in sample_shapes]
            gen_in = cond_in + noise_in
            disc_in_real = sample_in[0]
            disc_in_fake = self.gen(gen_in)
            # Random interpolate between real and fake samples; the gradient
            # penalty is evaluated at this interpolated point (WGAN-GP).
            disc_in_avg = RandomWeightedAverage()([disc_in_real, disc_in_fake])
            disc_out_real = self.disc(cond_in + [disc_in_real])
            disc_out_fake = self.disc(cond_in + [disc_in_fake])
            disc_out_avg = self.disc(cond_in + [disc_in_avg])
            disc_gp = GradientPenalty()([disc_out_avg, disc_in_avg])
            self.disc_trainer = Model(inputs=cond_in+sample_in+noise_in,
                outputs=[disc_out_real, disc_out_fake, disc_gp])

        self.compile()
Example #2
0
 def load(self, load_files):
     """Restore model weights and optimizer state from saved files.

     ``load_files`` maps the keys "gen_weights", "disc_weights",
     "gen_opt_weights" and "disc_opt_weights" to weight file paths.
     The train function of each trainer must be built (via the private
     Keras API ``_make_train_function``) before optimizer weights can
     be loaded into it.
     """
     self.gen.load_weights(load_files["gen_weights"])
     self.disc.load_weights(load_files["disc_weights"])

     # The two trainers are restored symmetrically: freeze the opposite
     # network, build the train function, then load the optimizer state.
     restore_plan = (
         (self.disc, self.gen_trainer, "gen_opt_weights"),
         (self.gen, self.disc_trainer, "disc_opt_weights"),
     )
     for frozen_net, trainer, opt_key in restore_plan:
         with Nontrainable(frozen_net):
             trainer._make_train_function()
             load_opt_weights(trainer, load_files[opt_key])
Example #3
0
    def compile(self, opt_disc=None, opt_gen=None):
        """Attach optimizers and losses to both trainer models.

        If no optimizers are given, Adam with beta_1=0.5, beta_2=0.9 (the
        standard WGAN-GP settings) is created from the stored learning
        rates. The discriminator trainer uses Wasserstein losses for the
        real/fake outputs and MSE for the gradient-penalty term, weighted
        by ``self.gradient_penalty_weight``.
        """
        # Fall back to default Adam optimizers when none are supplied.
        self.opt_disc = (
            Adam(self.lr_disc, beta_1=0.5, beta_2=0.9)
            if opt_disc is None else opt_disc
        )
        self.opt_gen = (
            Adam(self.lr_gen, beta_1=0.5, beta_2=0.9)
            if opt_gen is None else opt_gen
        )

        # Generator is trained with the discriminator frozen, and vice versa.
        with Nontrainable(self.disc):
            self.gen_trainer.compile(
                loss=wasserstein_loss,
                optimizer=self.opt_gen,
            )
        with Nontrainable(self.gen):
            self.disc_trainer.compile(
                loss=[wasserstein_loss, wasserstein_loss, 'mse'],
                loss_weights=[1.0, 1.0, self.gradient_penalty_weight],
                optimizer=self.opt_disc,
            )
Example #4
0
    def fit_generator(self,
                      batch_gen,
                      noise_gen,
                      steps_per_epoch=1,
                      num_epochs=1,
                      training_ratio=1):
        """Run the alternating GAN training loop.

        For each step, the discriminator is trained ``training_ratio``
        times (generator frozen), then the generator is trained once
        (discriminator frozen). Progress is reported per batch via a
        Keras progress bar.
        """
        # Fixed targets: +1 for real, -1 for fake (Wasserstein labels),
        # 0 for the gradient-penalty output.
        disc_out_shape = (batch_gen.batch_size, self.disc.output_shape[1])
        real_target = np.ones(disc_out_shape, dtype=np.float32)
        fake_target = -real_target
        gp_target = np.zeros_like(real_target)

        for epoch in range(num_epochs):
            print("Epoch {}/{}".format(epoch + 1, num_epochs))
            # One progress bar per epoch, measured in samples.
            progbar = generic_utils.Progbar(
                steps_per_epoch * batch_gen.batch_size)

            for step in range(steps_per_epoch):
                # Discriminator phase: generator frozen, possibly repeated.
                with Nontrainable(self.gen):
                    for _ in range(training_ratio):
                        image_batch = next(batch_gen)
                        noise_batch = next(noise_gen)
                        disc_loss = self.disc_trainer.train_on_batch(
                            [image_batch] + noise_batch,
                            [real_target, fake_target, gp_target])

                # Generator phase: discriminator frozen.
                with Nontrainable(self.disc):
                    noise_batch = next(noise_gen)
                    gen_loss = self.gen_trainer.train_on_batch(
                        noise_batch, real_target)

                # Report the last discriminator loss values and the
                # generator loss on the progress bar.
                losses = [("D{}".format(i), dl)
                          for (i, dl) in enumerate(disc_loss)]
                losses.append(("G0", gen_loss))
                progbar.add(batch_gen.batch_size, values=losses)
Example #5
0
    def build(self):
        """Build and compile the GAN training graphs.

        Creates and stores:
        - ``self.disc_trainer``: discriminator update path (generator
          frozen), producing [real score, fake score], trained with
          binary cross-entropy on each output.
        - ``self.gen_trainer``: generator update path (discriminator
          frozen), producing the discriminator's score for generated
          images, trained with binary cross-entropy.
        """
        img_shape = input_shapes(self.disc, "sample_in")[0]
        cond_shapes = input_shapes(self.gen, "cond_in")
        noise_shapes = input_shapes(self.gen, "noise_in")

        # Create optimizers (standard GAN Adam settings).
        self.opt_disc = Adam(self.lr_disc, beta_1=0.5, beta_2=0.9)
        self.opt_gen = Adam(self.lr_gen, beta_1=0.5, beta_2=0.9)

        # Build discriminator training network (generator frozen).
        with Nontrainable(self.gen):
            real_image = Input(shape=img_shape)
            cond = [Input(shape=s) for s in cond_shapes]
            noise = [Input(shape=s) for s in noise_shapes]

            disc_real = self.disc([real_image] + cond)
            # BUG FIX: the original `self.gen([cond]+[noise])` nested the
            # two input lists ([[...],[...]]); the generator expects a flat
            # list of tensors, as in the generator-trainer branch below.
            generated_image = self.gen(cond + noise)
            disc_fake = self.disc([generated_image] + cond)

            self.disc_trainer = Model(
                inputs=[real_image] + cond + noise,
                outputs=[disc_real, disc_fake]
            )
            self.disc_trainer.compile(optimizer=self.opt_disc,
                loss=["binary_crossentropy", "binary_crossentropy"])

        # Build generator training network (discriminator frozen).
        with Nontrainable(self.disc):
            cond = [Input(shape=s) for s in cond_shapes]
            noise = [Input(shape=s) for s in noise_shapes]

            generated_image = self.gen(cond + noise)
            disc_fake = self.disc([generated_image] + cond)

            self.gen_trainer = Model(
                inputs=cond + noise,
                outputs=disc_fake
            )
            self.gen_trainer.compile(optimizer=self.opt_gen,
                loss="binary_crossentropy")
Example #6
0
    def build(self):
        """Build and compile the (unconditional) WGAN-GP training graphs.

        Creates and stores:
        - ``self.disc_trainer``: discriminator update path (generator
          frozen), outputting [real score, fake score, gradient penalty],
          trained with Wasserstein losses plus an MSE gradient-penalty
          term weighted by 10 (the standard WGAN-GP value).
        - ``self.gen_trainer``: generator update path (discriminator
          frozen), trained with the Wasserstein loss.
        """
        img_shape = input_shapes(self.disc, "sample_in")[0]
        noise_shapes = input_shapes(self.gen, "noise_in")

        # Create optimizers (standard WGAN-GP Adam settings).
        self.opt_disc = Adam(self.lr_disc, beta_1=0.5, beta_2=0.9)
        self.opt_gen = Adam(self.lr_gen, beta_1=0.5, beta_2=0.9)

        # Build discriminator training network (generator frozen).
        with Nontrainable(self.gen):
            real_image = Input(shape=img_shape)
            noise = [Input(shape=s) for s in noise_shapes]

            disc_real = self.disc(real_image)
            generated_image = self.gen(noise)
            disc_fake = self.disc(generated_image)

            # Gradient penalty is evaluated at a random interpolation
            # between real and generated images.
            avg_image = RandomWeightedAverage()([real_image, generated_image])
            disc_avg = self.disc(avg_image)
            gp = GradientPenalty()([disc_avg, avg_image])

            # BUG FIX: the original `inputs=[real_image, noise]` placed the
            # `noise` list inside the inputs list as a nested element; Keras
            # Model expects a flat list of input tensors.
            self.disc_trainer = Model(inputs=[real_image] + noise,
                                      outputs=[disc_real, disc_fake, gp])
            self.disc_trainer.compile(
                optimizer=self.opt_disc,
                loss=[wasserstein_loss, wasserstein_loss, 'mse'],
                loss_weights=[1.0, 1.0, 10.0])

        # Build generator training network (discriminator frozen).
        with Nontrainable(self.disc):
            noise = [Input(shape=s) for s in noise_shapes]

            generated_image = self.gen(noise)
            disc_fake = self.disc(generated_image)

            self.gen_trainer = Model(inputs=noise, outputs=disc_fake)
            self.gen_trainer.compile(optimizer=self.opt_gen,
                                     loss=wasserstein_loss)
Example #7
0
    def train(self, batch_gen, noise_gen, num_gen_batches=1,
        training_ratio=1, show_progress=True):
        """Run the alternating WGAN-GP training loop.

        For each of ``num_gen_batches`` generator updates, the
        discriminator is first updated ``training_ratio`` times on fresh
        (sample, cond) batches, then the generator is updated once.

        Returns a numpy array of shape (num_gen_batches, n_losses)
        stacking the averaged discriminator losses and the generator
        loss for each generator batch.
        """
        if show_progress:
            # Initialize progbar, measured in samples.
            progbar = generic_utils.Progbar(
                num_gen_batches * batch_gen.batch_size)

        # Wasserstein targets: +1 for real, -1 for fake, 0 for the
        # gradient-penalty output.
        # (Removed a dead `disc_target_real = None` that was immediately
        # overwritten here.)
        disc_target_real = np.ones(
            (batch_gen.batch_size, batch_gen.num_frames, 1), dtype=np.float32)
        disc_target_fake = -disc_target_real
        gen_target = disc_target_real
        target_gp = np.zeros((batch_gen.batch_size, 1), dtype=np.float32)
        disc_target = [disc_target_real, disc_target_fake, target_gp]

        loss_log = []

        for k in range(num_gen_batches):

            # Train discriminator (generator frozen), averaging the loss
            # over `training_ratio` repetitions.
            disc_loss = None
            disc_loss_n = 0
            for rep in range(training_ratio):
                # Generate some real samples.
                (sample, cond) = next(batch_gen)
                noise = noise_gen()

                with Nontrainable(self.gen):
                    dl = self.disc_trainer.train_on_batch(
                        [cond, sample] + noise, disc_target)

                if disc_loss is None:
                    disc_loss = np.array(dl)
                else:
                    disc_loss += np.array(dl)
                disc_loss_n += 1

                # Release large batch arrays before the next iteration.
                del sample, cond

            disc_loss /= disc_loss_n

            # Train generator (discriminator frozen). The real sample is
            # drawn only for its conditioning information.
            with Nontrainable(self.disc):
                (sample, cond) = next(batch_gen)
                gen_loss = self.gen_trainer.train_on_batch(
                    [cond] + noise_gen(), gen_target)
                del sample, cond

            if show_progress:
                losses = []
                for (i, dl) in enumerate(disc_loss):
                    losses.append(("D{}".format(i), dl))
                for (i, gl) in enumerate([gen_loss]):
                    losses.append(("G{}".format(i), gl))
                progbar.add(batch_gen.batch_size,
                    values=losses)

            loss_log.append(np.hstack((disc_loss, gen_loss)))

            # Encourage release of large intermediate arrays.
            gc.collect()

        return np.array(loss_log)