def train_model(X_train, y_train, X_test, y_test, unbalance, target_classes,
                output_dir, epochs, dataset_name='CIFAR10'):
    """Train (or reload) a BAGAN model and save generated minority-class images.

    Parameters
    ----------
    X_train : Imbalanced training X
    y_train : Imbalanced training y
    X_test : Test X
    y_test : Test y
    unbalance : list
        The dropping ratios.
    target_classes : list
        Imbalanced classes chosen, corresponding to the `unbalance` ratios.
    output_dir : directory for output model and images
    epochs : training epochs
    dataset_name : Dataset name (used to build the result-directory path)
    """
    print("Executing BAGAN.")

    # Fixed run configuration (seed makes sampling reproducible and the
    # result-directory name deterministic).
    seed = 0
    np.random.seed(seed)
    gratio_mode = "uniform"
    dratio_mode = "uniform"
    adam_lr = 0.00005
    opt_class = target_classes
    batch_size = 128
    out_dir = output_dir

    print('Using dataset: ', dataset_name)

    # Result directory
    res_dir = "{}/res_{}_class_{}_ratio_{}_epochs_{}_seed_{}".format(
        out_dir, dataset_name, target_classes, unbalance, epochs, seed
    )
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    # Read initial data.
    print("read input data...")
    bg_train = BatchGenerator(X_train, y_train, batch_size=batch_size)
    bg_test = BatchGenerator(X_test, y_test, batch_size=batch_size)
    print("input data loaded...")

    shape = bg_train.get_image_shape()

    # Halve the spatial resolution until it is <= 8: this is the smallest
    # latent resolution for the GAN architecture.  Integer halving (//)
    # replaces the original float division + int() round-trip.
    min_latent_res = shape[-1]
    while min_latent_res > 8:
        min_latent_res //= 2
    min_latent_res = int(min_latent_res)

    classes = bg_train.get_label_table()

    # Generated samples per class, filled after training/loading below.
    img_samples = defaultdict(list)

    # For all possible minority classes.
    target_classes = np.array(range(len(classes)))
    if opt_class is not None:
        min_classes = np.array(opt_class)
    else:
        min_classes = target_classes

    # Train the model (or reload it if already available).
    if not (
        os.path.exists("{}/score.csv".format(res_dir)) and
        os.path.exists("{}/discriminator.h5".format(res_dir)) and
        os.path.exists("{}/generator.h5".format(res_dir)) and
        os.path.exists("{}/reconstructor.h5".format(res_dir))
    ):
        # Training required
        print("Required GAN for class {}".format(min_classes))
        print('Class counters: ', bg_train.per_class_count)

        # Train GAN to balance the data
        gan = bagan.BalancingGAN(
            target_classes, min_classes,
            dratio_mode=dratio_mode, gratio_mode=gratio_mode,
            adam_lr=adam_lr, res_dir=res_dir, image_shape=shape,
            min_latent_res=min_latent_res
        )
        gan.train(bg_train, bg_test, epochs=epochs)
        gan.save_history(res_dir, min_classes)
    else:
        # GAN pre-trained: rebuild the model and load the saved weights.
        print("Loading GAN for class {}".format(min_classes))
        gan = bagan.BalancingGAN(
            target_classes, min_classes,
            dratio_mode=dratio_mode, gratio_mode=gratio_mode,
            adam_lr=adam_lr, res_dir=res_dir, image_shape=shape,
            min_latent_res=min_latent_res
        )
        print('Load trained model')
        gan.load_models(
            "{}/generator.h5".format(res_dir),
            "{}/discriminator.h5".format(res_dir),
            "{}/reconstructor.h5".format(res_dir),
            bg_train=bg_train  # This is required to initialize the per-class mean and covariance matrix
        )

    for i in range(len(min_classes)):
        # Sample and save images for each minority class; the sample count
        # scales with that class's dropping ratio.
        c = min_classes[i]
        print('saving images for class {}'.format(c))
        sample_size = math.ceil(5000 * unbalance[i])
        img_samples['class_{}'.format(c)] = gan.generate_samples(
            c=c, samples=sample_size)
        save_image_files(
            np.array([img_samples['class_{}'.format(c)]])[0],
            c, unbalance[i], res_dir, dataset_name)
res_dir, c))): # Skip GAN training if results are available print("Required GAN for class {}".format(c)) # Unbalance the training. bg_train_partial = BatchGenerator(BatchGenerator.TRAIN, batch_size, class_to_prune=c, unbalance=unbalance) print('Class counters: ', bg_train_partial.per_class_count) # Train GAN to balance the data gan = bagan.BalancingGAN(target_classes, c, dratio_mode=dratio_mode, gratio_mode=gratio_mode, adam_lr=adam_lr, res_dir=res_dir, image_shape=shape, min_latent_res=min_latent_res) gan.train(bg_train_partial, bg_test, epochs=gan_epochs) gan.save_history(res_dir, c) else: # GAN pre-trained # Unbalance the training. print("Loading GAN for class {}".format(c)) bg_train_partial = BatchGenerator(BatchGenerator.TRAIN, batch_size, class_to_prune=c, unbalance=unbalance) gan = bagan.BalancingGAN(target_classes,
def generate_samples(X_train, y_train, unbalance, target_classes, input_dir,
                     epochs, dataset_name='CIFAR10', shape=(3, 32, 32)):
    """Load a trained BAGAN and rebalance the training set with generated images.

    Each minority class is topped up to 5000 samples (the generated images are
    appended to ``X_train``/``y_train``, which are returned).

    Parameters
    ----------
    X_train : Imbalanced training X
    y_train : Imbalanced training y
    unbalance : list
        The dropping ratios.
    target_classes : list
        Imbalanced classes chosen, corresponding to the `unbalance` ratios.
    input_dir : directory for input model and images (same naming rule as the
        training function)
    epochs : training epochs (used only to locate the result directory)
    dataset_name : Dataset name
    shape : unused — the actual image shape is taken from the data below.
        (Changed from a mutable list default to a tuple; value is never read.)

    Raises
    ------
    FileNotFoundError
        If the result directory or the trained model files are missing.
        (The original raised FileExistsError, which is the wrong exception
        for a missing path; both subclass OSError.)
    """
    print("Executing BAGAN.")

    # Configuration must match the values used at training time so the
    # result-directory path resolves to the trained model.
    seed = 0
    np.random.seed(seed)
    gratio_mode = "uniform"
    dratio_mode = "uniform"
    adam_lr = 0.00005
    opt_class = target_classes
    batch_size = 128
    in_dir = input_dir

    # Result directory (must already exist and contain the trained model).
    res_dir = "{}/res_{}_class_{}_ratio_{}_epochs_{}_seed_{}".format(
        in_dir, dataset_name, target_classes, unbalance, epochs, seed)
    if not os.path.exists(res_dir):
        raise FileNotFoundError("Input directory doesn't exist")

    # Read initial data.
    print("read input data...")
    bg_train = BatchGenerator(X_train, y_train, batch_size=batch_size)
    print("input data loaded...")

    shape = bg_train.get_image_shape()

    # Halve the spatial resolution until it is <= 8: the smallest latent
    # resolution for the GAN architecture (same rule as training).
    min_latent_res = shape[-1]
    while min_latent_res > 8:
        min_latent_res //= 2
    min_latent_res = int(min_latent_res)

    classes = bg_train.get_label_table()

    # For all possible minority classes.
    target_classes = np.array(range(len(classes)))
    if opt_class is not None:
        min_classes = np.array(opt_class)
    else:
        min_classes = target_classes

    # Reload the trained model; this function never trains.
    if (os.path.exists("{}/score.csv".format(res_dir)) and
            os.path.exists("{}/discriminator.h5".format(res_dir)) and
            os.path.exists("{}/generator.h5".format(res_dir)) and
            os.path.exists("{}/reconstructor.h5".format(res_dir))):
        print("Loading GAN for class {}".format(min_classes))
        gan = bagan.BalancingGAN(
            target_classes, min_classes,
            dratio_mode=dratio_mode, gratio_mode=gratio_mode,
            adam_lr=adam_lr, res_dir=res_dir, image_shape=shape,
            min_latent_res=min_latent_res)
        print('Load trained model')
        gan.load_models(
            "{}/generator.h5".format(res_dir),
            "{}/discriminator.h5".format(res_dir),
            "{}/reconstructor.h5".format(res_dir),
            bg_train=bg_train  # This is required to initialize the per-class mean and covariance matrix
        )
    else:
        raise FileNotFoundError("Trained model doesn't exist")

    for i in range(len(min_classes)):
        # Sample generated images for each minority class.
        c = min_classes[i]
        print('generating images for class {}'.format(c))
        # Number of samples needed to top this class up to 5000.
        sample_size = int(5000 - np.sum(y_train == c))
        if sample_size <= 0:
            # Class already has >= 5000 samples; the original would have
            # requested a negative sample count here.
            continue
        sim_images = gan.generate_samples(c=c, samples=sample_size)
        # Channels-first -> channels-last, then map [-1, 1] generator output
        # into [0, 1] (presumably tanh output — confirm against the model).
        sim_images = np.transpose(sim_images, axes=(0, 2, 3, 1))
        sim_images = (sim_images + 1) / 2
        X_train = np.concatenate((X_train, sim_images))
        y_train = np.concatenate((y_train, c * np.ones(sample_size)))
    return X_train, y_train