def test_create_tif(ground_truth_raster, tmp_path):
    """Round-trip a raster through ``resample.create_tif`` and check nothing changes.

    Reads every band of *ground_truth_raster*, writes the array back out with
    ``resample.create_tif`` (passing the same file as ``source_tif``), re-opens
    the written file and asserts that the dataset shape and the pixel values
    match the original.
    """
    # Context managers close the dataset handles even when an assertion fails;
    # the source is closed before create_tif writes the copy.
    with rasterio.open(ground_truth_raster) as src:
        data = src.read()
        original_shape = src.shape

    filename = "{}/test.tif".format(tmp_path)
    resample.create_tif(source_tif=ground_truth_raster, filename=filename, numpy_array=data)

    # Assert the original and saved rasters are identical.
    with rasterio.open(filename) as saved_src:
        saved_data = saved_src.read()
        saved_shape = saved_src.shape

    assert saved_shape == original_shape
    np.testing.assert_array_almost_equal(data, saved_data)
def test_resample(ground_truth_raster, training_raster, tmp_path):
    """Assert that the resampled training raster has the grid of the ground truth raster.

    Writes the first band of *training_raster* to a temporary GeoTIFF with
    ``resample.create_tif``, resamples that file with ``resample.resample`` and
    checks that the result has the same dataset shape and bounds as
    *ground_truth_raster*.
    """
    # Only band 1 is written; expand_dims adds a leading axis so the array is
    # (1, rows, cols), the same layout rasterio's all-band read() returns.
    with rasterio.open(training_raster) as src:
        data = np.expand_dims(src.read(1), 0)

    filename = "{}/test.tif".format(tmp_path)
    resample.create_tif(source_tif=training_raster, filename=filename, numpy_array=data)
    resampled_filename = resample.resample(filename)

    # Context managers close both datasets even when an assertion fails.
    with rasterio.open(resampled_filename) as resampled_src:
        with rasterio.open(ground_truth_raster) as ground_truth_raster_src:
            resampled_data = resampled_src.read()
            assert ground_truth_raster_src.shape == resampled_src.shape
            assert ground_truth_raster_src.bounds == resampled_src.bounds
            # The decoded pixel array, not just the header, must be on the
            # ground-truth (rows, cols) grid.
            assert resampled_data.shape[-2:] == ground_truth_raster_src.shape
#Predict predict_tfrecords = glob.glob( "/orange/ewhite/b.weinstein/Houston2018/tfrecords/predict/*.tfrecord") results = model.predict_raster(predict_tfrecords, batch_size=512) #predicted classes print(results.label.unique()) predicted_raster = visualize.create_raster(results) print(np.unique(predicted_raster)) experiment.log_image(name="Prediction", image_data=predicted_raster, image_colormap=visualize.discrete_cmap(20, base_cmap="jet")) #Save as tif for resampling prediction_path = os.path.join(save_dir, "prediction.tif") predicted_raster = np.expand_dims(predicted_raster, 0) resample.create_tif( "/home/b.weinstein/DeepTreeAttention/data/processed/20170218_UH_CASI_S4_NAD83.tif", filename=prediction_path, numpy_array=predicted_raster) filename = resample.resample(prediction_path) experiment.log_image(name="Resampled Prediction", image_data=filename, image_colormap=visualize.discrete_cmap(20, base_cmap="jet")) #Save model model.model.save("{}/{}.h5".format(save_dir, timestamp))