def viz_img_from_z_dist(self, z_dist, epoch):
    n_classes = self.opts['n_classes']
    # 10 samples per class: labels [0]*10, [1]*10, ...
    y = np.repeat(np.arange(0, n_classes).reshape(n_classes, 1), 10)
    z = self.sample_pz(len(y), z_dist, y)
    # Decode the sampled latent codes into images.
    sample_gen = self.sess.run(
        self.decoded,
        feed_dict={self.sample_noise: z, self.is_training: False})
    # Reshape to (n_classes, samples_per_class, H, W, C) for the image grid.
    sample_gen = sample_gen.reshape(
        self.opts['n_classes'],
        sample_gen.shape[0] // self.opts['n_classes'],
        sample_gen.shape[-3], sample_gen.shape[-2], sample_gen.shape[-1])
    utils.save_image_array(
        sample_gen,
        self.opts['work_dir'] + os.sep + "img_from_z_dist_{}.png".format(epoch))
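# NOTE: standalone sketch (not from the original code) illustrating how the
# repeated label vector and the 5-D reshape above organize samples into a
# (n_classes, samples_per_class, H, W, C) grid. Pure NumPy, no model required;
# all sizes below are arbitrary placeholders.
import numpy as np

n_classes, per_class, H, W, C = 3, 10, 4, 4, 1
y = np.repeat(np.arange(0, n_classes).reshape(n_classes, 1), per_class)
assert y.shape == (n_classes * per_class,)          # [0]*10, [1]*10, [2]*10

flat = np.random.rand(n_classes * per_class, H, W, C)   # stands in for decoded samples
grid = flat.reshape(n_classes, flat.shape[0] // n_classes,
                    flat.shape[-3], flat.shape[-2], flat.shape[-1])
assert grid.shape == (n_classes, per_class, H, W, C)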
def train(self, bg_train, bg_test, epochs=50):
    if not self.trained:
        self.autoenc_epochs = epochs

        # Class actual ratio
        self.class_aratio = bg_train.get_class_probability()

        # Class balancing ratio
        self._set_class_ratios()
        print("uratio set to: {}".format(self.class_uratio))
        print("dratio set to: {}".format(self.class_dratio))
        print("gratio set to: {}".format(self.class_gratio))

        # Initialization
        print("BAGAN init_autoenc")
        self.init_autoenc(bg_train)
        print("BAGAN autoenc initialized, init gan")
        start_e = self.init_gan()
        print("BAGAN gan initialized, start_e: ", start_e)

        # Initial comparison plot per class: real images, autoencoded
        # reconstructions, and generated samples.
        crt_c = 0
        act_img_samples = bg_train.get_samples_for_class(crt_c, 10)
        img_samples = np.array([[
            act_img_samples,
            self.generator.predict(
                self.reconstructor.predict(act_img_samples)),
            self.generate_samples(crt_c, 10, bg_train)
        ]])
        for crt_c in range(1, self.nclasses):
            act_img_samples = bg_train.get_samples_for_class(crt_c, 10)
            new_samples = np.array([[
                act_img_samples,
                self.generator.predict(
                    self.reconstructor.predict(act_img_samples)),
                self.generate_samples(crt_c, 10, bg_train)
            ]])
            img_samples = np.concatenate((img_samples, new_samples), axis=0)

        shape = img_samples.shape
        img_samples = img_samples.reshape(
            (-1, shape[-4], shape[-3], shape[-2], shape[-1]))
        save_image_array(
            img_samples,
            '{}/cmp_class_{}_init.png'.format(self.res_dir, self.target_class_id))

        # Train
        for e in range(start_e, epochs):
            print('Epoch {} of {} (dratio_mode={}, gratio_mode={})'.format(
                e + 1, epochs, self.dratio_mode, self.gratio_mode))
            # train_disc_loss, train_gen_loss = self._train_one_epoch(copy.deepcopy(bg_train))
            train_disc_loss, train_gen_loss = self._train_one_epoch(bg_train)

            # Test:
            # generate a new batch of noise
            nb_test = bg_test.get_num_samples()
            fake_size = int(np.ceil(nb_test * 1.0 / self.nclasses))
            sampled_labels = self._biased_sample_labels(nb_test, "d")
            latent_gen = self.generate_latent(sampled_labels, bg_test)

            # sample some labels from p_c and generate images from them
            generated_images = self.generator.predict(latent_gen, verbose=False)
            X = np.concatenate((bg_test.dataset_x, generated_images))
            aux_y = np.concatenate(
                (bg_test.dataset_y, np.full(len(sampled_labels), self.nclasses)),
                axis=0)

            # see if the discriminator can figure itself out...
            test_disc_loss = self.discriminator.evaluate(X, aux_y, verbose=False)

            # make new latent
            sampled_labels = self._biased_sample_labels(fake_size + nb_test, "g")
            latent_gen = self.generate_latent(sampled_labels, bg_test)
            test_gen_loss = self.combined.evaluate(latent_gen, sampled_labels,
                                                   verbose=False)

            # generate an epoch report on performance
            self.train_history['disc_loss'].append(train_disc_loss)
            self.train_history['gen_loss'].append(train_gen_loss)
            self.test_history['disc_loss'].append(test_disc_loss)
            self.test_history['gen_loss'].append(test_gen_loss)
            print("train_disc_loss {},\ttrain_gen_loss {},\t"
                  "test_disc_loss {},\ttest_gen_loss {}".format(
                      train_disc_loss, train_gen_loss,
                      test_disc_loss, test_gen_loss))

            # Save sample images
            if e % 10 == 9:
                img_samples = np.array([
                    self.generate_samples(c, 10, bg_train)
                    for c in range(0, self.nclasses)
                ])
                save_image_array(
                    img_samples,
                    '{}/plot_class_{}_epoch_{}.png'.format(
                        self.res_dir, self.target_class_id, e))

            # Generate whole evaluation plot (real img, autoencoded img, fake img)
            if e % 10 == 5:
                self.backup_point(e)
                crt_c = 0
                act_img_samples = bg_train.get_samples_for_class(crt_c, 10)
                img_samples = np.array([[
                    act_img_samples,
                    self.generator.predict(
                        self.reconstructor.predict(act_img_samples)),
                    self.generate_samples(crt_c, 10, bg_train)
                ]])
                for crt_c in range(1, self.nclasses):
                    act_img_samples = bg_train.get_samples_for_class(crt_c, 10)
                    new_samples = np.array([[
                        act_img_samples,
                        self.generator.predict(
                            self.reconstructor.predict(act_img_samples)),
                        self.generate_samples(crt_c, 10, bg_train)
                    ]])
                    img_samples = np.concatenate((img_samples, new_samples),
                                                 axis=0)

                shape = img_samples.shape
                img_samples = img_samples.reshape(
                    (-1, shape[-4], shape[-3], shape[-2], shape[-1]))
                save_image_array(
                    img_samples,
                    '{}/cmp_class_{}_epoch_{}.png'.format(
                        self.res_dir, self.target_class_id, e))

        self.trained = True
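# NOTE: standalone sketch (not from the original code) of the labeling scheme
# used in the discriminator test above: real test images keep their class
# labels 0..nclasses-1, while all generated images are assigned the extra
# "fake" label equal to nclasses. Pure NumPy; the label values are placeholders.
import numpy as np

nclasses = 5
real_y = np.array([0, 3, 1, 4, 2])                  # stands in for bg_test.dataset_y
n_fake = 4                                           # number of generated images
aux_y = np.concatenate((real_y, np.full(n_fake, nclasses)), axis=0)
# -> array([0, 3, 1, 4, 2, 5, 5, 5, 5]); class index 5 means "fake"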
else:  # GAN pre-trained
    # Unbalance the training.
    print("Loading GAN for class {}".format(c))
    bg_train_partial = BatchGenerator(BatchGenerator.TRAIN, batch_size,
                                      class_to_prune=c, unbalance=unbalance)

    gan = bagan.BalancingGAN(target_classes, c,
                             dratio_mode=dratio_mode, gratio_mode=gratio_mode,
                             adam_lr=adam_lr, res_dir=res_dir,
                             image_shape=shape, min_latent_res=min_latent_res)
    gan.load_models(
        "{}/class_{}_generator.h5".format(res_dir, c),
        "{}/class_{}_discriminator.h5".format(res_dir, c),
        "{}/class_{}_reconstructor.h5".format(res_dir, c),
        bg_train=bg_train_partial  # required to initialize the per-class mean and covariance matrix
    )

    # Sample and save images
    img_samples['class_{}'.format(c)] = gan.generate_samples(c=c, samples=10)
    save_image_array(np.array([img_samples['class_{}'.format(c)]]),
                     '{}/plot_class_{}.png'.format(res_dir, c))
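# NOTE: hypothetical helper (not in the original repo) that checks whether the
# three per-class model files expected by the pre-trained branch above exist,
# following the same "{res_dir}/class_{c}_<part>.h5" naming used in load_models.
import os

def pretrained_files_exist(res_dir, c):
    parts = ("generator", "discriminator", "reconstructor")
    return all(os.path.isfile("{}/class_{}_{}.h5".format(res_dir, c, part))
               for part in parts)

# Example use: choose between training from scratch and the pre-trained branch.
# if pretrained_files_exist(res_dir, c): ... load models ... else: ... train ...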
def train(self, bg_train, bg_test, epochs=100, class_num=10,
          latent_size=100, mode_z='uniform', batch_size=100,
          gen_class_ration=[]):
    if not self.trained:
        # Class actual ratio
        self.class_aratio = bg_train.get_class_probability()

        # Fixed latent codes, reused every epoch so sample grids are comparable.
        fixed_latent = self.generate_latent(batch_size, latent_size, mode_z)

        # Train
        start_e = 0
        for e in range(start_e, epochs):
            start_time = time()

            # Train
            print('GAN train epoch: {}/{}'.format(e, epochs))
            train_classifier_loss, train_gen_loss = self._train_one_epoch(
                bg_train, class_num, batch_size=batch_size, mode_z=mode_z,
                gen_class_ration=gen_class_ration)
            loss_R = train_gen_loss[0] - train_gen_loss[1] - train_gen_loss[2]
            self.result_logger.add_training_metrics1(
                float(train_gen_loss[0]), float(train_gen_loss[1]),
                float(train_gen_loss[2]), float(loss_R),
                float(train_classifier_loss[0]), float(train_classifier_loss[1]),
                float(train_classifier_loss[2]), time() - start_time)

            # Test
            test_loss = self.classifier.evaluate(
                bg_test.dataset_x, [bg_test.dataset_y, bg_test.dataset_y],
                verbose=False)
            self.result_logger.add_testing_metrics(test_loss[0], test_loss[1],
                                                   test_loss[2])

            probs_0, probs_1 = self.classifier.predict(
                bg_test.dataset_x, batch_size=batch_size, verbose=True)
            final_probs = probs_1
            predicts = np.argmax(final_probs, axis=-1)
            self.result_logger.save_prediction(e, bg_test.dataset_y, predicts,
                                               probs_0, probs_1, epochs=epochs)
            self.result_logger.save_metrics()

            print("train_classifier_loss {},\ttrain_gen_loss {},\t".format(
                train_classifier_loss, train_gen_loss))

            # Save sample images (every epoch)
            if e % 1 == 0:
                final_latent = self.final_latent.predict(
                    fixed_latent, batch_size=batch_size)
                generated_images = self.generator.predict(
                    final_latent, verbose=0, batch_size=batch_size)
                img_samples = generated_images / 2. + 0.5  # rescale from [-1, 1] back to [0, 1]
                save_image_array(img_samples,
                                 '{}/plot_epoch_{}.png'.format(self.res_dir, e),
                                 batch_size=batch_size, class_num=10)

            if e % 1 == 0:
                self.backup_point(e, epochs)

        self.trained = True
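# NOTE: standalone sketch (not from the original code) of the rescaling used
# before saving sample images above: a tanh-output generator produces values
# in [-1, 1], which x/2 + 0.5 maps into [0, 1] for image saving.
import numpy as np

tanh_out = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])    # stands in for generator output
img = tanh_out / 2. + 0.5
# -> array([0.  , 0.25, 0.5 , 0.75, 1.  ])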