def test_download_if_needed_local(self):
    """Check that download_if_needed hands back a local path unchanged.

    The test first expects reading self.local_path to fail, then writes
    the fixture content there, copies the file onto itself and finally
    asserts that download_if_needed returns the very same path.
    """
    # Reading is expected to fail before anything has been written.
    self.assertRaises(NotReadableError, file_to_str, self.local_path)

    source_path = self.local_path
    str_to_file(self.content_str, source_path)
    # Source and destination are the same local path here.
    upload_or_copy(source_path, source_path)

    resolved_path = download_if_needed(source_path, self.tmp_dir.name)
    self.assertEqual(resolved_path, self.local_path)
def test_file_to_str_local(self):
    """Round-trip a string through a local file and check the read error.

    Writes self.content_str to self.local_path, reads it back and
    compares, then expects NotReadableError for a path that is presumed
    not to exist.
    """
    str_to_file(self.content_str, self.local_path)
    self.assertEqual(self.content_str, file_to_str(self.local_path))

    # A path under a directory presumed to be missing must not be readable.
    self.assertRaises(NotReadableError, file_to_str, '/wrongpath/x.txt')
def test_file_to_str_s3(self):
    """Round-trip a string through S3 and check both error paths.

    Expects NotWritableError when writing to, and NotReadableError when
    reading from, an S3 URI whose bucket is presumed not to exist. In
    between, content written to self.s3_path must read back unchanged.
    """
    bad_uri = 's3://wrongpath/x.txt'
    self.assertRaises(
        NotWritableError, str_to_file, self.content_str, bad_uri)

    str_to_file(self.content_str, self.s3_path)
    self.assertEqual(self.content_str, file_to_str(self.s3_path))

    self.assertRaises(NotReadableError, file_to_str, bad_uri)
def test_download_if_needed_s3(self):
    """Upload a local file to S3, download it again and compare content.

    Also expects NotReadableError when reading self.s3_path before the
    upload, and NotWritableError when uploading to an S3 URI whose
    bucket is presumed not to exist.
    """
    # Reading the S3 object is expected to fail before the upload.
    self.assertRaises(NotReadableError, file_to_str, self.s3_path)

    str_to_file(self.content_str, self.local_path)
    upload_or_copy(self.local_path, self.s3_path)

    downloaded_path = download_if_needed(self.s3_path, self.tmp_dir.name)
    self.assertEqual(self.content_str, file_to_str(downloaded_path))

    self.assertRaises(
        NotWritableError, upload_or_copy, downloaded_path,
        's3://wrongpath/x.txt')
def plot_training_logs(log_uris, labels):
    """Plot building F1 against epoch for several training logs.

    Each log is read as CSV and must provide 'epoch' and 'building_f1'
    columns. One line per log is drawn with pyplot, then the axis
    labels, title and legend are set.

    Args:
        log_uris: iterable of URIs of CSV training logs.
        labels: iterable of legend labels, paired with log_uris by
            position (pairs are formed with zip, so extra items in the
            longer iterable are ignored).
    """
    for uri, series_label in zip(log_uris, labels):
        frame = pd.read_csv(StringIO(file_to_str(uri)))
        plt.plot(
            frame['epoch'].to_numpy(),
            frame['building_f1'].to_numpy(),
            label=series_label)

    plt.xlabel('epoch')
    plt.ylabel('building f1')
    plt.title('Trained on 1%')
    plt.legend()
def _get_geojson(self):
    """Load the GeoJSON at self.vs_config.uri and run class inference on it.

    Returns:
        The result of self.class_inference.transform_geojson applied to
        the parsed GeoJSON.

    Raises:
        Exception: if the GeoJSON has a 'crs' field and
            self.vs_config.ignore_crs_field is not set.
    """
    uri = self.vs_config.uri
    geojson = json.loads(file_to_str(uri))

    crs_ignored = self.vs_config.ignore_crs_field
    if not crs_ignored and 'crs' in geojson:
        message = (
            'The GeoJSON file at {} contains a CRS field which is not '
            'allowed by the current GeoJSON standard or by Raster Vision. '
            'All coordinates are expected to be in EPSG:4326 CRS. If the file uses '
            'EPSG:4326 (ie. lat/lng on the WGS84 reference ellipsoid) and you would '
            'like to ignore the CRS field, set ignore_crs_field=True in '
            'GeoJSONVectorSourceConfig.').format(uri)
        raise Exception(message)

    return self.class_inference.transform_geojson(geojson)
def filter_geojson(labels_uri, output_uri, class_names):
    """Remove features that aren't in class_names and remove class_ids.

    A feature's class is taken from its 'class_name' property, falling
    back to 'label' when 'class_name' is missing or empty. Features
    without properties, or whose class is not in class_names, are
    dropped. The 'class_id' property is removed from kept features when
    present.

    Args:
        labels_uri: URI of the input GeoJSON; it must have a top-level
            'features' list.
        output_uri: URI the filtered GeoJSON is written to. The output
            object contains only a 'features' key.
        class_names: container of class names to keep.
    """
    labels = json.loads(file_to_str(labels_uri))

    filtered_features = []
    for feature in labels['features']:
        properties = feature.get('properties')
        if not properties:
            continue

        # Use .get for the fallback: calling the dict itself raises
        # TypeError.
        class_name = properties.get('class_name') or properties.get('label')
        if class_name not in class_names:
            continue

        # Copy only the features that are kept so the parsed input is
        # left untouched without paying for copies of dropped features.
        kept = copy.deepcopy(feature)
        # Not every feature carries a class_id; tolerate its absence.
        kept['properties'].pop('class_id', None)
        filtered_features.append(kept)

    new_labels = {'features': filtered_features}
    str_to_file(json.dumps(new_labels), output_uri)
def get_scene_info(csv_uri):
    """Read the comma-delimited file at csv_uri and return all of its rows.

    Args:
        csv_uri: URI of the CSV file.

    Returns:
        A list with one entry per CSV row, each a list of string fields.
    """
    return list(csv.reader(StringIO(file_to_str(csv_uri)), delimiter=','))
def get_acc(metrics_uri):
    """Return the top-1 'res5avg' test accuracy from a metrics file.

    The value is read from the second-to-last newline-separated line of
    the file, which is parsed as JSON.

    Args:
        metrics_uri: URI of the metrics file.
    """
    lines = file_to_str(metrics_uri).split('\n')
    # Presumably the file ends with a newline, which makes [-2] the last
    # JSON record -- confirm against the metrics writer.
    record = json.loads(lines[-2])
    return record['test_accuracy_list_meter']['top_1']['res5avg']
def load(stats_uri):
    """Build a RasterStats from the JSON file at stats_uri.

    Args:
        stats_uri: URI of a JSON file with 'means' and 'stds' keys.

    Returns:
        A new RasterStats whose means and stds attributes hold the
        values read from the file.
    """
    parsed = json.loads(file_to_str(stats_uri))

    stats = RasterStats()
    stats.means, stats.stds = parsed['means'], parsed['stds']
    return stats
def print_messages(self):
    """Print the content of every file listed in self.config.message_uris."""
    for uri in self.config.message_uris:
        print(file_to_str(uri))
def print_messages(self):
    """Read each message file in self.config.message_uris and print it."""
    # Files are handled one at a time, in the order the config lists them.
    for uri in self.config.message_uris:
        text = file_to_str(uri)
        print(text)