def predict(data_path, model_str, ckpt_path_=''):
    """Predict masks for the test set with a trained model and save them.

    Builds the network from the ``designs`` module selected by ``model_str``,
    loads ``weights_<model_str>.h5``, runs it on the preprocessed and
    normalized test images, saves the raw predictions with ``np.save`` and
    writes channel 0 of every prediction as ``<id>_pred.png``.

    Args:
        data_path: Root data directory; test data is read with
            ``load_test_data(data_path)``.
        model_str: Name of the model design (a module inside ``designs``).
        ckpt_path_: Optional checkpoint directory. When empty, weights are
            read from ``<data_path>/internal/checkpoints``, the ``.npy`` file
            goes to ``<data_path>/internal/npy`` and PNGs to
            ``<data_path>/output/<model_str>``. Otherwise weights are read
            from ``ckpt_path_`` and both outputs go to
            ``<ckpt_path_>/predictions/<model_str>``.

    Raises:
        ValueError: If ``model_str`` is not a known model design.
    """
    import importlib

    # Model designs available in the ``designs`` package; add new models here.
    # Each name maps to the module of the same name, matching train().
    known_designs = (
        'unet',
        'unet64filters',
        'unet64filters_weighted',
        'flatunet',
        'unet64batchnorm',
        'u64_onehot',
        'nomaxpool',
    )
    if model_str not in known_designs:
        raise ValueError('Unknown model_str {0!r}; expected one of {1}'.format(
            model_str, list(known_designs)))
    design = importlib.import_module('designs.' + model_str)

    def _ensure_dir(path):
        # makedirs also creates missing parents; the isdir guard avoids
        # relying on makedirs(exist_ok=...), which older Pythons lack.
        if not os.path.isdir(path):
            os.makedirs(path)

    print('Creating and compiling model...')
    model = design.build()

    print('Loading and preprocessing test data...')
    imgs_test, imgs_id_test = load_test_data(data_path)
    imgs_test = normalize(preprocess(imgs_test))

    use_custom_ckpt = bool(ckpt_path_)
    if use_custom_ckpt:
        ckpt_path = ckpt_path_
    else:
        ckpt_path = os.path.join(data_path, 'internal/checkpoints')
    ckpt_file = os.path.join(ckpt_path, 'weights_' + model_str + '.h5')
    # Print before loading so a failed load still shows which file was tried.
    print('Loading saved weights :', ckpt_file)
    model.load_weights(ckpt_file)

    print('Predicting masks on test data...')
    imgs_mask_test = model.predict(imgs_test, verbose=1)

    if use_custom_ckpt:
        out_path = os.path.join(ckpt_path, 'predictions')
        npy_dir = os.path.join(out_path, model_str)
    else:
        out_path = os.path.join(data_path, 'output')
        npy_dir = os.path.join(data_path, 'internal/npy')
    _ensure_dir(npy_dir)
    np.save(os.path.join(npy_dir, 'imgs_mask_test.npy'), imgs_mask_test)

    png_dir = os.path.join(out_path, model_str)
    print('Saving predicted masks to files to:', png_dir)
    _ensure_dir(png_dir)
    for image, image_id in zip(imgs_mask_test, imgs_id_test):
        # Channel 0 scaled by 255 to uint8; assumes predictions lie in
        # [0, 1] (e.g. sigmoid output) -- TODO confirm against the designs.
        image = (image[:, :, 0] * 255.).astype(np.uint8)
        imsave(os.path.join(png_dir, '{0:0>5}_pred.png'.format(image_id)),
               image)
def train(data_path, model_str, number_of_epochs=2, batch_size=10,
          test_data_fraction=0.15, checkpoint_period=10,
          load_prev_weights=False, early_stop_patience=10):
    """Train the model design selected by ``model_str`` on the training set.

    Training images and masks are loaded with ``load_train_data(data_path)``,
    preprocessed and normalized. Checkpoints (best ``val_loss``, every
    ``checkpoint_period`` epochs) and a CSV epoch log are written to
    ``<data_path>/internal/checkpoints``. After fitting, the model is saved
    there as ``weights_<model_str>.h5``.

    Args:
        data_path: Root data directory.
        model_str: Name of the model design (a module inside ``designs``).
        number_of_epochs: Maximum number of training epochs.
        batch_size: Mini-batch size passed to ``model.fit``.
        test_data_fraction: Fraction of training data held out for
            validation (``validation_split``).
        checkpoint_period: Epoch interval for the ``ModelCheckpoint`` callback.
        load_prev_weights: If true, initialise from
            ``<data_path>/internal/prev_checkpoints_to_load/weights_<model_str>.h5``.
        early_stop_patience: Patience (epochs) for early stopping on
            ``val_loss``.

    Returns:
        Whatever ``model.fit`` returns (a ``History`` object in standard
        Keras), or ``None`` if ``load_prev_weights`` is set and the previous
        weights could not be loaded.

    Raises:
        ValueError: If ``model_str`` is not a known model design.
    """
    import importlib

    # Model designs available in the ``designs`` package; add new models here.
    known_designs = (
        'unet',
        'unet64filters',
        'unet64filters_weighted',
        'flatunet',
        'unet64batchnorm',
        'u64_onehot',
        'nomaxpool',
    )
    if model_str not in known_designs:
        raise ValueError('Unknown model_str {0!r}; expected one of {1}'.format(
            model_str, list(known_designs)))
    design = importlib.import_module('designs.' + model_str)

    # DATA LOADING AND PREPROCESSING
    print('Loading and preprocessing train data...')
    imgs_train, imgs_mask_train = load_train_data(data_path)
    imgs_train = normalize(preprocess(imgs_train))
    imgs_mask_train = normalize_mask(preprocess(imgs_mask_train))

    if 'weighted' in model_str:
        # Error-weight maps are only consumed by the weighted designs.
        imgs_ew = normalize_errorweight(preprocess(load_ew_data(data_path)))
        # Stack each mask with its error-weight map along axis 2.
        # TODO(review): mask_combined is never passed to model.fit below, so
        # weighted designs are currently trained on the plain masks only.
        # Confirm which target the weighted loss expects and wire this in.
        mask_combined = [np.concatenate([mask, ew], axis=2)
                         for mask, ew in zip(imgs_mask_train, imgs_ew)]

    # BUILD MODEL
    print('Creating and compiling model...')
    model = design.build()

    # Callbacks: periodic best-val_loss checkpoints, CSV epoch log and
    # early stopping.
    ckpt_dir = os.path.join(data_path, 'internal/checkpoints')
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)
    ckpt_file = os.path.join(ckpt_dir,
                             'weights_' + model_str + '_epoch{epoch:02d}.h5')
    model_checkpoint = ModelCheckpoint(ckpt_file, monitor='val_loss',
                                       save_best_only=True,
                                       save_weights_only=True,
                                       period=checkpoint_period)
    csv_logger = CSVLogger(os.path.join(ckpt_dir, 'log_' + model_str + '.csv'))
    early_stop = EarlyStopping(monitor='val_loss',
                               patience=early_stop_patience, verbose=1)

    if load_prev_weights:
        prev_weights = os.path.join(data_path,
                                    'internal/prev_checkpoints_to_load',
                                    'weights_' + model_str + '.h5')
        try:
            model.load_weights(prev_weights)
        except Exception as e:
            # Deliberate best-effort: report the problem and skip training.
            print('Problem loading the saved weight!')
            print(e)
            return None
        print('Loading prev weights:', prev_weights)

    # FIT MODEL
    print('Fitting model...')
    model_out = model.fit(imgs_train, imgs_mask_train,
                          batch_size=batch_size,
                          epochs=number_of_epochs,
                          verbose=1,
                          shuffle=True,
                          validation_split=test_data_fraction,
                          callbacks=[model_checkpoint, csv_logger, early_stop])
    # Save model and final weights.
    model.save(os.path.join(ckpt_dir, 'weights_' + model_str + '.h5'))
    return model_out