def create_data_generator(data_dir, other_classes, batch_size, class_mode):
    """
    Creates a data generator for real and fake face images.

    The generator always contains the class "real", plus every class listed
    in `other_classes`. When `class_mode` is "binary", all classes other than
    "real" are collapsed into a single "fake" label (fake = 0, real = 1).

    Args:
        data_dir: Directory containing class directories.
        other_classes: Collection of classes other than "real". The class
            "real" will be included in the generator by default, after all
            of these. Order matters when `class_mode` is not "binary".
        batch_size: Number of images to process at a time.
        class_mode: See
            `keras.preprocessing.image.ImageDataGenerator.flow_from_directory`.

    Returns:
        A DirectoryIterator and a dictionary of class weights.

        The class weights map class indices to the inverse of their sample
        count. For instance, if there were 100 images belonging to the first
        fake class, 50 to the second fake class, and 20 to the real class,
        the weights would be

            { 0 : 0.01, 1 : 0.02, 2 : 0.05 }

        The classes with fewer images are weighted more heavily. These
        weights can help combat class sample imbalances during training.
        A class index with no images has no entry in the dictionary.
    """
    # Initialize generator. "real" is appended last, so it receives the
    # highest class index in the generator's class-to-index mapping.
    classes = list(other_classes) + ['real']
    generator = ImageDataGenerator(rescale=1/255).flow_from_directory(
        data_dir,
        classes=classes,
        target_size=IMG_SIZE,
        batch_size=batch_size,
        class_mode=class_mode,
        subset='training')

    if class_mode == 'binary':
        # Modify data labels: every non-real class becomes 0, real becomes 1.
        real_index = generator.class_indices['real']
        generator.classes = np.array(
            [1 if i == real_index else 0 for i in generator.classes],
            dtype=np.int32)

        # Change class-to-index mapping to match the collapsed labels.
        generator.class_indices = {'fake': 0, 'real': 1}

    # Calculate the weights. Key each weight by the label value reported by
    # `np.unique` rather than by its position in `counts`: if some class has
    # no images, positions and label values no longer line up and the weights
    # would be attached to the wrong class indices.
    labels, counts = np.unique(generator.classes, return_counts=True)
    weights = {int(label): 1 / count for label, count in zip(labels, counts)}

    return generator, weights
def load_single_class_generators(data_dir, classes, batch_size=16):
    """
    Creates a dictionary of data generators for a list of classes.

    Each generator reads images of exactly one class. Images of the class
    "real" are labeled 1; images of any other class keep the label 0 that a
    single-class generator assigns.

    Args:
        data_dir: Directory containing classes.
        classes: List of classes to make generators for. Should have
            corresponding directories within the directory pointed to by
            `data_dir`.
        batch_size: Number of images to read at a time.

    Returns:
        Dictionary mapping class names to data generators.

    Raises:
        SystemExit: With exit status 1, after printing an error to stderr,
            if a class has no directory inside `data_dir`.
    """
    # Local import so this function does not depend on how (or whether) the
    # surrounding module imports `sys`.
    import sys

    generators = {}
    for c in classes:
        path = os.path.join(data_dir, c)
        if not os.path.isdir(path):
            print('ERROR: No directory for class "{}" in "{}"'.format(c, data_dir),
                  file=sys.stderr)
            # A missing class directory is a failure: exit with a non-zero
            # status so shells and calling scripts can detect it.
            sys.exit(1)

        gen = ImageDataGenerator(rescale=1/255).flow_from_directory(
            data_dir,
            classes=[c],
            target_size=IMG_SIZE,
            batch_size=batch_size,
            class_mode='binary',
            subset='training')

        # Real images need to be labeled "1" and not "0".
        if c == 'real':
            # Modify data labels.
            gen.classes = np.ones(gen.classes.shape, dtype=np.int32)
            # Change class-to-index mapping.
            gen.class_indices = {'real': 1}

        generators[c] = gen

    return generators