def test_normalize_image_from_db_local(db_with_image, tmpdir, db_init):
    """
    Tests that normalize_drawing_from_db loads the image referenced by a metadata
    record, normalizes it, and writes the result to the new store.

    Both the output metadata record and the normalized image file are inspected.

    Future: Feels very scattered. Break into a few tests?
    """
    # Verbose way to get m_id just in case we change the data creation routines.
    # This ensures there's a single record.
    s = create_session()
    all_records = s.query(ImageRecord).all()
    s.close()

    assert len(all_records) == 1
    assert all_records[0].file_normalized is None
    m_id = all_records[0].id

    normalize_drawing_from_db(m_id, tmpdir)  # (metadata_id, normalized_storage_location)

    s = create_session()
    all_records = s.query(ImageRecord).all()
    s.close()

    assert len(all_records) == 1
    normalized_image_filename = all_records[0].file_normalized
    assert normalized_image_filename == os.path.join(tmpdir, f"{FAKE_LABEL}_{m_id}.npy")

    normalized_image = np.load(normalized_image_filename)
    assert normalized_image.shape == (28, 28)
@pytest.fixture
def db_with_image(dummy_image):
    m = ImageRecord()
    m.file_raw = dummy_image
    m.label = FAKE_LABEL

    s = create_session()
    s.add(m)
    s.commit()
    s.close()

    # Not sure why, but if I don't close and reopen the session I get a thread error.
    s = create_session()
    m_objs = s.query(ImageRecord).all()
    print([str(m) for m in m_objs])
    s.close()
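# Side note (an assumption, not confirmed by this code): the "thread error"
# above is typically SQLite's "objects created in a thread can only be used in
# that same thread" check. If create_session's engine is built with SQLAlchemy,
# one common workaround is relaxing that check when creating the engine --
# a minimal sketch, assuming a SQLite URL:
from sqlalchemy import create_engine

engine = create_engine(
    "sqlite:///./test.sqlite",
    connect_args={"check_same_thread": False},  # relax SQLite's same-thread check
)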
def find_records_unnormalized() -> List[ImageRecord]:
    s = create_session()
    # A plain "is None" / "is not None" comparison on a Column is evaluated by
    # Python, not SQL. Use .is_() so the filter compiles to "file_normalized IS NULL"
    # (this function should return records that have NOT been normalized yet).
    q = s.query(ImageRecord)\
        .filter(ImageRecord.file_normalized.is_(None))
    results = q.all()
    s.close()
    return results
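# A minimal, self-contained illustration of the pitfall fixed above: "is not None"
# on a Column is a Python identity check (always True), while .is_(None) /
# .isnot(None) compile to SQL IS NULL / IS NOT NULL. The model here is
# hypothetical, purely for demonstration.
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Demo(Base):
    __tablename__ = "demo"
    id = Column(Integer, primary_key=True)
    file_normalized = Column(String)

print(Demo.file_normalized is not None)  # True -- a Python identity check, not SQL
print(Demo.file_normalized.is_(None))    # demo.file_normalized IS NULL
print(Demo.file_normalized.isnot(None))  # demo.file_normalized IS NOT NULL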
def create_training_data_from_image_db(test_size=0.2, random_state=42):
    """
    Creates a train-test split of the current image db, stored in a TrainingData table

    Args:
        test_size (float): Fraction of records held out for the test split
        random_state (int): Seed for the split, for reproducibility

    Returns:
        None
    """
    normalized_images = find_records_with_label_normalized()
    labels = [img.label for img in normalized_images]

    index_to_label = sorted(set(labels))
    label_to_index = {index_to_label[i]: i for i in range(len(index_to_label))}

    training_images, testing_images = sklearn.model_selection.train_test_split(
        normalized_images,
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )

    td = TrainingData(label_to_index=label_to_index, index_to_label=index_to_label)
    td.training_images.extend(training_images)
    td.testing_images.extend(testing_images)

    s = create_session()
    s.add(td)
    s.commit()
    s.close()
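# A quick standalone check of the stratified split used above: with
# stratify=labels, sklearn keeps each label's proportion roughly equal across
# the train and test sets. Runnable on its own; toy data only.
import sklearn.model_selection

records = list(range(10))
labels = ["cat"] * 5 + ["dog"] * 5

train, test = sklearn.model_selection.train_test_split(
    records, test_size=0.2, stratify=labels, random_state=42
)
print(len(train), len(test))  # 8 2 -- one cat and one dog land in the test set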
def find_records_with_label_normalized(label: str = None) -> List[ImageRecord]:
    s = create_session()
    q = s.query(ImageRecord)
    if label:
        q = q.filter(ImageRecord.label == label)
    q = q.filter(ImageRecord.file_normalized.isnot(None))
    results = q.all()
    s.close()
    return results
def add_record_to_metadata(
        label: str = None,
        raw_storage_location: str = None,
        normalized_storage_location: str = None) -> ImageRecord:
    m = ImageRecord()
    m.label = label
    m.file_raw = raw_storage_location
    m.file_normalized = normalized_storage_location

    s = create_session()
    s.add(m)
    # Commit data to the DB but keep an unexpired version to pass back to the caller.
    # This is essentially a copy of the state of m at the time of commit, but it won't
    # automatically refresh its data if we interact with it (expire_on_commit=True
    # tells m it needs to resync next time we use it).
    s.expire_on_commit = False
    s.commit()
    s.close()
    return m
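# A minimal, self-contained sketch of the expire_on_commit behaviour relied on
# above: with expire_on_commit=False, attributes of a committed object stay
# readable after the session closes; with the default (True), touching them
# after close would trigger a refresh attempt and raise DetachedInstanceError.
# Toy model and in-memory DB, for illustration only.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Record(Base):
    __tablename__ = "records"
    id = Column(Integer, primary_key=True)
    label = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine, expire_on_commit=False)

s = Session()
r = Record(label="cat")
s.add(r)
s.commit()
s.close()
print(r.id, r.label)  # 1 cat -- safe to read because r was never expired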
def load_training_data_to_dataframe(training_data_id=None):
    """
    Returns train and test DataFrames plus the label index for a TrainingData entry

    By default returns the most recent TrainingData entry.

    Args:
        training_data_id (TrainingData.id): (OPTIONAL) If specified, returns the
            TrainingData with this id. Otherwise returns the most recently created
            TrainingData

    Returns:
        Tuple of:
            (pd.DataFrame): Training images and integer-encoded labels
            (pd.DataFrame): Testing images and integer-encoded labels
            (list): List of labels indexed the same as the encoded labels
    """
    s = create_session()
    if training_data_id:
        td: TrainingData = s.query(TrainingData).filter(
            TrainingData.id == training_data_id).first()
        if not td:
            raise ValueError(
                f"Cannot find TrainingData entry with id = {training_data_id}")
    else:
        td: TrainingData = s.query(TrainingData).order_by(
            TrainingData.created_date.desc()).first()
        if not td:
            raise ValueError("Cannot find any TrainingData entries")

    # Build pd.DataFrame objects that are suitable for
    # tf.keras.preprocessing.image.ImageDataGenerator
    df_train = _image_records_to_dataframe(td.training_images, td.label_to_index)
    df_test = _image_records_to_dataframe(td.testing_images, td.label_to_index)
    return df_train, df_test, td.index_to_label
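# _image_records_to_dataframe isn't shown here; a plausible sketch, assuming it
# maps each ImageRecord to its normalized file path plus a label the way
# ImageDataGenerator.flow_from_dataframe expects (the "filename"/"class" column
# names are that API's defaults; the helper body itself is my guess):
import pandas as pd

def _image_records_to_dataframe(image_records, label_to_index):
    """Build a DataFrame of (filename, class) rows from ImageRecords."""
    return pd.DataFrame({
        "filename": [r.file_normalized for r in image_records],
        "class": [label_to_index[r.label] for r in image_records],
    })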
def find_record_by_id(metadata_id: int) -> ImageRecord:
    s = create_session()
    m = s.query(ImageRecord).filter(ImageRecord.id == metadata_id).first()
    s.close()
    return m
import os

from quick_redraw.data.training_data import TrainingData
from quick_redraw.data.training_data_record import TrainingDataRecord
# global_init, create_session, and ImageRecord are assumed to be imported from
# the project's session/model modules above.

db_file = './training_data_db_inserts.sqlite'

print("DELETING OLD TEMP DB")
if os.path.exists(db_file):
    os.remove(db_file)

global_init(db_file)

image = ImageRecord(label='cat', file_raw='raw.png', file_normalized='norm.png')

tdrs = [
    TrainingDataRecord(),
    TrainingDataRecord(),
    TrainingDataRecord(),
]

tdrs[0].image = image

index_to_label = ['cat', 'dog']
label_to_index = {'cat': 0, 'dog': 1}
td = TrainingData(index_to_label=index_to_label, label_to_index=label_to_index)
td.train.extend(tdrs)

s = create_session()
s.add(td)
s.commit()