Ejemplo n.º 1
0
    def viz_img_from_z_dist(self, z_dist, epoch, samples_per_class=10):
        """Decode latent samples drawn from ``z_dist`` and save them as an image grid.

        For every class label ``0 .. n_classes - 1`` this draws
        ``samples_per_class`` latent codes via ``self.sample_pz``. It decodes them
        through ``self.decoded`` with ``self.is_training`` fed as False. The
        result is written to ``<work_dir>/img_from_z_dist_<epoch>.png`` via
        ``utils.save_image_array``.

        Args:
            z_dist: distribution spec, forwarded unchanged to ``self.sample_pz``.
            epoch: value embedded in the output file name.
            samples_per_class: number of latent codes drawn per class
                (defaults to 10, the previously hard-coded value).
        """
        n_classes = self.opts['n_classes']
        # Labels laid out class-major: [0]*k + [1]*k + ..., so the decoded
        # batch can be reshaped into one row per class below.
        y = np.repeat(np.arange(n_classes), samples_per_class)
        z = self.sample_pz(len(y), z_dist, y)

        sample_gen = self.sess.run(
            self.decoded,
            feed_dict={self.sample_noise: z,
                       self.is_training: False})
        # Group decoded images by class: (n_classes, N // n_classes, *last 3 axes).
        # NOTE(review): assumes the decoder returns one image per latent code
        # with the image in the last three axes — confirm.
        sample_gen = sample_gen.reshape(
            (n_classes, sample_gen.shape[0] // n_classes) + sample_gen.shape[-3:])
        utils.save_image_array(
            sample_gen,
            os.path.join(self.opts['work_dir'],
                         "img_from_z_dist_{}.png".format(epoch)))
Ejemplo n.º 2
0
    def train(self, bg_train, bg_test, epochs=50):
        """Initialise and adversarially train the balancing GAN.

        This is a no-op if ``self.trained`` is already set. The steps are:

        1. Record the class ratios and initialise the autoencoder from ``bg_train``.
        2. Initialise the GAN; ``init_gan()`` returns the epoch to start from.
        3. Save an initial real / reconstructed / generated comparison plot.
        4. Train from that start epoch up to ``epochs``.

        After each epoch the discriminator and the combined model are
        evaluated on ``bg_test``. The losses are appended to
        ``self.train_history`` / ``self.test_history``.

        Args:
            bg_train: training batch generator.
            bg_test: test batch generator (uses ``dataset_x``, ``dataset_y``
                and ``get_num_samples()``).
            epochs: exclusive upper bound of the epoch loop; also stored as
                ``self.autoenc_epochs``.
        """
        if self.trained:
            return

        self.autoenc_epochs = epochs

        # Class actual ratio
        self.class_aratio = bg_train.get_class_probability()

        # Class balancing ratios
        self._set_class_ratios()
        print("uratio set to: {}".format(self.class_uratio))
        print("dratio set to: {}".format(self.class_dratio))
        print("gratio set to: {}".format(self.class_gratio))

        # Initialization
        print("BAGAN init_autoenc")
        self.init_autoenc(bg_train)
        print("BAGAN autoenc initialized, init gan")
        start_e = self.init_gan()
        print("BAGAN gan initialized, start_e: ", start_e)

        save_image_array(
            self._build_comparison_samples(bg_train),
            '{}/cmp_class_{}_init.png'.format(self.res_dir,
                                              self.target_class_id))

        # Train
        for e in range(start_e, epochs):
            print('Epoch {} of {}'.format(e + 1, epochs))
            train_disc_loss, train_gen_loss = self._train_one_epoch(bg_train)

            # Test: generate a new batch of noise
            nb_test = bg_test.get_num_samples()
            fake_size = int(np.ceil(nb_test * 1.0 / self.nclasses))
            sampled_labels = self._biased_sample_labels(nb_test, "d")
            latent_gen = self.generate_latent(sampled_labels, bg_test)

            # Sample some labels from p_c and generate images from them
            generated_images = self.generator.predict(latent_gen,
                                                      verbose=False)

            # Real test images keep their labels; generated ones are labelled
            # with the extra class index `self.nclasses`.
            X = np.concatenate((bg_test.dataset_x, generated_images))
            aux_y = np.concatenate(
                (bg_test.dataset_y,
                 np.full(len(sampled_labels), self.nclasses)),
                axis=0)

            # See if the discriminator can figure itself out...
            test_disc_loss = self.discriminator.evaluate(X,
                                                         aux_y,
                                                         verbose=False)

            # Make a new latent batch for evaluating the combined model.
            sampled_labels = self._biased_sample_labels(
                fake_size + nb_test, "g")
            latent_gen = self.generate_latent(sampled_labels, bg_test)

            test_gen_loss = self.combined.evaluate(latent_gen,
                                                   sampled_labels,
                                                   verbose=False)

            # Generate an epoch report on performance
            self.train_history['disc_loss'].append(train_disc_loss)
            self.train_history['gen_loss'].append(train_gen_loss)
            self.test_history['disc_loss'].append(test_disc_loss)
            self.test_history['gen_loss'].append(test_gen_loss)
            print(
                "train_disc_loss {},\ttrain_gen_loss {},\ttest_disc_loss {},\ttest_gen_loss {}"
                .format(train_disc_loss, train_gen_loss, test_disc_loss,
                        test_gen_loss))

            # Save sample images
            if e % 10 == 9:
                img_samples = np.array([
                    self.generate_samples(c, 10, bg_train)
                    for c in range(0, self.nclasses)
                ])

                save_image_array(
                    img_samples, '{}/plot_class_{}_epoch_{}.png'.format(
                        self.res_dir, self.target_class_id, e))

            # Back up, then save the full evaluation plot
            # (real img, autoencoded img, fake img).
            if e % 10 == 5:
                self.backup_point(e)
                save_image_array(
                    self._build_comparison_samples(bg_train),
                    '{}/cmp_class_{}_epoch_{}.png'.format(
                        self.res_dir, self.target_class_id, e))

        self.trained = True

    def _build_comparison_samples(self, bg_train, n_samples=10):
        """Return real / autoencoded / generated images for every class, stacked for plotting.

        For each class ``c`` this collects three image batches:

        - ``n_samples`` real images from ``bg_train``;
        - their reconstruction through ``generator(reconstructor(.))``;
        - ``n_samples`` freshly generated images.

        The resulting ``(nclasses, 3, ...)`` array is flattened over its first
        two axes. The last four axes of each image batch are kept (presumably
        ``(n_samples, H, W, C)`` — confirm).
        """
        rows = []
        for crt_c in range(self.nclasses):
            act_img_samples = bg_train.get_samples_for_class(crt_c, n_samples)
            rows.append([
                act_img_samples,
                self.generator.predict(
                    self.reconstructor.predict(act_img_samples)),
                self.generate_samples(crt_c, n_samples, bg_train),
            ])
        # Convert once instead of repeatedly concatenating per class.
        img_samples = np.array(rows)
        return img_samples.reshape((-1,) + img_samples.shape[-4:])
Ejemplo n.º 3
0
        else:  # GAN pre-trained
            # Unbalance the training.
            print("Loading GAN for class {}".format(c))
            bg_train_partial = BatchGenerator(BatchGenerator.TRAIN,
                                              batch_size,
                                              class_to_prune=c,
                                              unbalance=unbalance)

            gan = bagan.BalancingGAN(target_classes,
                                     c,
                                     dratio_mode=dratio_mode,
                                     gratio_mode=gratio_mode,
                                     adam_lr=adam_lr,
                                     res_dir=res_dir,
                                     image_shape=shape,
                                     min_latent_res=min_latent_res)
            gan.load_models(
                "{}/class_{}_generator.h5".format(res_dir, c),
                "{}/class_{}_discriminator.h5".format(res_dir, c),
                "{}/class_{}_reconstructor.h5".format(res_dir, c),
                bg_train=
                bg_train_partial  # This is required to initialize the per-class mean and covariance matrix
            )

        # Sample and save images
        img_samples['class_{}'.format(c)] = gan.generate_samples(c=c,
                                                                 samples=10)

        save_image_array(np.array([img_samples['class_{}'.format(c)]]),
                         '{}/plot_class_{}.png'.format(res_dir, c))
Ejemplo n.º 4
0
    def train(self,
              bg_train,
              bg_test,
              epochs=100,
              class_num=10,
              latent_size=100,
              mode_z='uniform',
              batch_size=100,
              gen_class_ration=None):
        """Train for ``epochs`` epochs, logging metrics and saving sample plots each epoch.

        This is a no-op if ``self.trained`` is already set. Each epoch it:

        - runs ``self._train_one_epoch`` and logs training metrics;
        - evaluates ``self.classifier`` on ``bg_test``;
        - saves predictions and metrics via ``self.result_logger``;
        - saves an image grid generated from a latent batch fixed before
          training;
        - calls ``self.backup_point``.

        Args:
            bg_train: training batch generator.
            bg_test: test batch generator (uses ``dataset_x`` / ``dataset_y``).
            epochs: number of epochs to run.
            class_num: number of classes; forwarded to ``_train_one_epoch``
                and to ``save_image_array``.
            latent_size: latent dimension passed to ``generate_latent``.
            mode_z: latent sampling mode passed to ``generate_latent`` and
                ``_train_one_epoch``.
            batch_size: batch size for training, prediction and the fixed
                latent batch.
            gen_class_ration: forwarded to ``_train_one_epoch``; ``None``
                (the default) means a fresh empty list.
        """
        if self.trained:
            return
        # Avoid a mutable default shared across calls.
        if gen_class_ration is None:
            gen_class_ration = []

        # Class actual ratio
        self.class_aratio = bg_train.get_class_probability()
        # Generated once and reused every epoch, so successive sample plots
        # come from the same latent inputs.
        fixed_latent = self.generate_latent(batch_size, latent_size, mode_z)

        for e in range(epochs):
            start_time = time()
            # Train
            print('GAN train epoch: {}/{}'.format(e, epochs))
            train_classifier_loss, train_gen_loss = self._train_one_epoch(
                bg_train,
                class_num,
                batch_size=batch_size,
                mode_z=mode_z,
                gen_class_ration=gen_class_ration)

            # NOTE(review): presumably train_gen_loss[0] is a total loss and
            # [1], [2] are component losses, making loss_R the remainder — confirm.
            loss_R = train_gen_loss[0] - train_gen_loss[1] - train_gen_loss[2]
            self.result_logger.add_training_metrics1(
                float(train_gen_loss[0]), float(train_gen_loss[1]),
                float(train_gen_loss[2]), float(loss_R),
                float(train_classifier_loss[0]),
                float(train_classifier_loss[1]),
                float(train_classifier_loss[2]),
                time() - start_time)

            # Test: the classifier has two outputs, both scored against the labels.
            test_loss = self.classifier.evaluate(
                bg_test.dataset_x, [bg_test.dataset_y, bg_test.dataset_y],
                verbose=False)
            self.result_logger.add_testing_metrics(test_loss[0],
                                                   test_loss[1],
                                                   test_loss[2])
            probs_0, probs_1 = self.classifier.predict(
                bg_test.dataset_x, batch_size=batch_size, verbose=True)
            # Final predictions come from the second output head.
            predicts = np.argmax(probs_1, axis=-1)
            self.result_logger.save_prediction(e,
                                               bg_test.dataset_y,
                                               predicts,
                                               probs_0,
                                               probs_1,
                                               epochs=epochs)
            self.result_logger.save_metrics()
            print("train_classifier_loss {},\ttrain_gen_loss {},\t".format(
                train_classifier_loss, train_gen_loss))

            # Save sample images generated from the fixed latent batch.
            final_latent = self.final_latent.predict(
                fixed_latent, batch_size=batch_size)
            generated_images = self.generator.predict(
                final_latent, verbose=0, batch_size=batch_size)
            # Rescale from [-1, 1] back to [0, 1].
            img_samples = generated_images / 2. + 0.5
            save_image_array(img_samples,
                             '{}/plot_epoch_{}.png'.format(self.res_dir, e),
                             batch_size=batch_size,
                             class_num=class_num)

            self.backup_point(e, epochs)
        self.trained = True