from tensorflow.keras.callbacks import ModelCheckpoint root_path = os.path.abspath(os.path.join('..')) if root_path not in sys.path: sys.path.append(root_path) import config from qr_codes_loader import load_qr_codes, build_dataset from ssd_utils.ssd_loss import SSDLoss from utils import import_by_name, train_test_split_tensors, MeanAveragePrecisionCallback # Load train and validation data trainval_image_paths, trainval_bnd_boxes = load_qr_codes(split='train') (train_image_paths, valid_image_paths, train_bnd_boxes, valid_bnd_boxes) = train_test_split_tensors(trainval_image_paths, trainval_bnd_boxes, test_size=config.VALID_SAMPLES, random_state=config.RANDOM_SEED) valid_data = build_dataset(valid_image_paths, valid_bnd_boxes, image_size=config.IMAGE_SIZE, batch_size=config.BATCH_SIZE) for run in range(1, config.NUM_RUNS + 1): weights_dir = 'weights_{}'.format(run) history_dir = 'history_{}'.format(run) os.makedirs(weights_dir, exist_ok=True) os.makedirs(history_dir, exist_ok=True) for architecture in config.ARCHITECTURES:
history_file = model_name + '_history.pickle' history_path = os.path.join(history_dir, history_file) with open(history_path, 'wb') as f: pickle.dump(history.history, f) model.load_weights(model_file) model.save_weights(model_path) os.remove(model_file) del model, fake_data for train_samples in config.TRAIN_SAMPLES: (_, small_train_image_paths, _, small_train_bnd_boxes) = train_test_split_tensors( train_image_paths, train_bnd_boxes, test_size=train_samples, random_state=config.RANDOM_SEED) assert small_train_image_paths.shape[ 0] == small_train_bnd_boxes.shape[0] assert small_train_image_paths.shape[0] == train_samples for train_type in ['from_scratch', 'finetuned']: model_name = architecture.lower() + '_{}_samples_{}'.format( train_samples, train_type) model_file = model_name + '.h5' model_path = os.path.join(weights_dir, model_file) if not os.path.exists(model_path): print('\n\nINFO: Training {} {} on {} samples'.format(