def data_generator(type='DATA_GEN'):
    '''
    Generate an endless stream of (x_data, y_data) training batches.

    DATA_GEN: synthesize random graphics with ``gen_data`` and use small discs
              (via ``generate_random_gap``) as gaps -- non-meaningful data.
    DATA_GAP: read user line-drawings from ``./data/line`` and punch gaps into
              them with ``generate_random_gap`` -- meaningful data.
    DATA_THIN: read offline inputs from ``./data/thin`` (produced by
               normalization/thinning) paired with targets from ``./data/line``.

    :param type: 'DATA_GEN', 'DATA_GAP' or 'DATA_THIN'. The name shadows the
        builtin but is kept so keyword callers keep working.
    :return: generator yielding (x_data, y_data) tuples.
    :raises ValueError: on the first ``next()`` when ``type`` is not one of the
        supported modes.

    NOTE(review): this file also defines a zero-argument ``data_generator``
    further down; if both live in one module the later definition wins --
    confirm which one callers are meant to use.
    '''
    # Validate up front so a typo fails loudly instead of producing a generator
    # that silently yields nothing.
    if type not in ('DATA_GEN', 'DATA_GAP', 'DATA_THIN'):
        raise ValueError(
            "Unknown data type %r; expected 'DATA_GEN', 'DATA_GAP' or 'DATA_THIN'"
            % (type,))

    # Gap configurations handed to generate_random_gap; the meaning of each
    # field is defined there. Training with both 352- and 176-sized inputs was
    # noted to achieve better performance.
    gap_configs352 = [[50, 600, 2, 8, 0, 1],
                      [50, 600, 2, 10, 0, 2],
                      [1, 2, 5, 15, 0, 3]]

    if type == 'DATA_GEN':
        # Create the RNG once: re-creating RandomState(SEED) for every batch
        # would restart the same random sequence and yield identical batches.
        rnd = np.random.RandomState(SEED)
        while True:
            # Size config is in datagen.py
            # NOTE(review): elsewhere in this file gen_data's result is unpacked
            # as a (raw, norm) pair; confirm it returns a single batch here.
            train_y_batch = gen_data(rnd, BATCH_SIZE)
            # NOTE(review): the same SEED is passed for every batch; if
            # generate_random_gap re-seeds from it, every batch gets the same
            # gap pattern -- confirm this is intended.
            train_x_batch, _ = generate_random_gap(train_y_batch, gap_configs352, SEED)
            yield train_x_batch, train_y_batch

    # Only the directory-backed modes need image augmentation.
    datagen = image.ImageDataGenerator(rescale=1 / 255.,
                                       rotation_range=180,
                                       width_shift_range=0.1,
                                       height_shift_range=0.1,
                                       zoom_range=0.2,
                                       horizontal_flip=True,
                                       vertical_flip=True,
                                       fill_mode='reflect')
    # Shared flow settings for every directory iterator below.
    flow_kwargs = dict(target_size=(IMG_HEIGHT, IMG_WIDTH),
                       color_mode='grayscale',
                       seed=SEED,
                       class_mode=None,
                       batch_size=BATCH_SIZE,
                       shuffle=True,
                       interpolation='bilinear')

    if type == 'DATA_GAP':
        raw_generator_352 = datagen.flow_from_directory('./data/line', **flow_kwargs)
        while True:
            train_y_batch = next(raw_generator_352)
            train_x_batch, _ = generate_random_gap(train_y_batch, gap_configs352, SEED)
            yield train_x_batch, train_y_batch

    # type == 'DATA_THIN'
    # NOTE(review): x/y pairing presumably relies on both iterators sharing
    # the seed and on ./data/thin and ./data/line listing matching files in
    # the same order; confirm, including that random augmentations line up.
    raw_generator_x = datagen.flow_from_directory('./data/thin', **flow_kwargs)
    raw_generator_y = datagen.flow_from_directory('./data/line', **flow_kwargs)
    while True:
        yield next(raw_generator_x), next(raw_generator_y)
def _val_generator():
    """Yield validation batches from ``gen_data`` indefinitely.

    A single RandomState seeded with ``SEED + 1`` is created up front and
    reused for every batch. Its seed is offset from the training seed
    (``SEED``) used by ``data_generator``.
    """
    val_rng = np.random.RandomState(seed=SEED + 1)
    while 1:
        batch = gen_data(val_rng, BATCH_SIZE)
        yield batch
def data_generator():
    """Yield (raw, norm) torch tensor batches indefinitely.

    Each batch comes from ``gen_data`` using one RandomState seeded with
    ``SEED`` and reused across batches. ``gen_data`` is expected to return a
    pair of NumPy arrays. Each array is wrapped with ``torch.from_numpy``
    (memory is shared, not copied) and its axes are reordered with
    ``permute(0, 3, 1, 2)``, which moves the last axis to position 1
    (presumably NHWC -> NCHW; confirm against gen_data).

    NOTE(review): this rebinds the name ``data_generator`` if an earlier
    definition exists in the same module.
    """
    rng = np.random.RandomState(seed=SEED)
    while 1:
        raw_np, norm_np = gen_data(rng, BATCH_SIZE)
        # Convert raw first, then norm, and return them as a 2-tuple.
        yield tuple(torch.from_numpy(arr).permute(0, 3, 1, 2)
                    for arr in (raw_np, norm_np))