def test_download_if_needed_local(self):
        """download_if_needed returns a local path unchanged.

        The local file must be unreadable before it is written. After writing
        it and copying it onto itself, download_if_needed called with that
        same path is expected to hand back the identical path.
        """
        not_yet_written = self.local_path
        with self.assertRaises(NotReadableError):
            file_to_str(not_yet_written)

        str_to_file(self.content_str, self.local_path)
        # Copying a local file onto itself must complete without raising.
        upload_or_copy(self.local_path, self.local_path)

        resolved_path = download_if_needed(self.local_path, self.tmp_dir.name)
        self.assertEqual(resolved_path, self.local_path)
    def test_file_to_str_local(self):
        """file_to_str round-trips a local file and rejects a missing path."""
        str_to_file(self.content_str, self.local_path)
        read_back = file_to_str(self.local_path)
        self.assertEqual(self.content_str, read_back)

        # A path the test expects not to exist must not be readable.
        missing_path = '/wrongpath/x.txt'
        with self.assertRaises(NotReadableError):
            file_to_str(missing_path)
    def test_file_to_str_s3(self):
        """Round-trip a string through S3 and check errors on a bad URI."""
        bad_uri = 's3://wrongpath/x.txt'

        # Writing to a bucket the test expects to be unwritable must fail.
        with self.assertRaises(NotWritableError):
            str_to_file(self.content_str, bad_uri)

        str_to_file(self.content_str, self.s3_path)
        read_back = file_to_str(self.s3_path)
        self.assertEqual(self.content_str, read_back)

        # Reading from that same bad URI must fail as well.
        with self.assertRaises(NotReadableError):
            file_to_str(bad_uri)
    def test_download_if_needed_s3(self):
        """download_if_needed fetches an uploaded S3 object into tmp_dir."""
        # The S3 object must not be readable before it is uploaded.
        with self.assertRaises(NotReadableError):
            file_to_str(self.s3_path)

        str_to_file(self.content_str, self.local_path)
        upload_or_copy(self.local_path, self.s3_path)

        fetched_path = download_if_needed(self.s3_path, self.tmp_dir.name)
        self.assertEqual(self.content_str, file_to_str(fetched_path))

        # Uploading to a bucket the test expects to be unwritable must fail.
        bad_uri = 's3://wrongpath/x.txt'
        with self.assertRaises(NotWritableError):
            upload_or_copy(fetched_path, bad_uri)
示例#5
0
def plot_training_logs(log_uris, labels, title='Trained on 1%'):
    """Plot building F1 against epoch for several training logs on one axes.

    Args:
        log_uris: URIs of CSV training logs. Each CSV is read with pandas and
            must contain ``epoch`` and ``building_f1`` columns.
        labels: Legend label for each log, paired positionally with
            ``log_uris``. As with ``zip``, surplus entries in the longer
            sequence are ignored.
        title: Title for the plot.
    """
    plotted_any = False
    for log_uri, label in zip(log_uris, labels):
        log_df = pd.read_csv(StringIO(file_to_str(log_uri)))
        epoch = log_df['epoch'].to_numpy()
        building_f1 = log_df['building_f1'].to_numpy()
        plt.plot(epoch, building_f1, label=label)
        plotted_any = True

    # Axis decorations do not depend on the loop variables, so apply them
    # once after every curve is drawn; the legend then covers all lines.
    # Skip them when nothing was plotted so empty input stays a no-op.
    if plotted_any:
        plt.xlabel('epoch')
        plt.ylabel('building f1')
        plt.title(title)
        plt.legend()
    def _get_geojson(self):
        """Load the configured GeoJSON file and pass it through class inference.

        Returns:
            Whatever ``self.class_inference.transform_geojson`` returns for
            the parsed document.

        Raises:
            Exception: if ``ignore_crs_field`` is falsy on the config and the
                parsed document contains a top-level ``crs`` key.
        """
        uri = self.vs_config.uri
        geojson = json.loads(file_to_str(uri))

        # Check the config flag first so the membership test is skipped
        # entirely when CRS fields are being ignored.
        ignore_crs = self.vs_config.ignore_crs_field
        if not ignore_crs and 'crs' in geojson:
            template = (
                'The GeoJSON file at {} contains a CRS field which is not '
                'allowed by the current GeoJSON standard or by Raster Vision. '
                'All coordinates are expected to be in EPSG:4326 CRS. If the file uses '
                'EPSG:4326 (ie. lat/lng on the WGS84 reference ellipsoid) and you would '
                'like to ignore the CRS field, set ignore_crs_field=True in '
                'GeoJSONVectorSourceConfig.')
            raise Exception(template.format(uri))

        return self.class_inference.transform_geojson(geojson)
示例#7
0
def filter_geojson(labels_uri, output_uri, class_names):
    """Remove features that aren't in class_names and remove class_ids.

    Reads the GeoJSON document at ``labels_uri``. A feature is kept when its
    ``class_name`` property (falling back to ``label`` when ``class_name``
    is missing or falsy) is in ``class_names``. Features with no properties,
    or an empty properties mapping, are dropped. Each kept feature is
    deep-copied and its ``class_id`` property, if any, is removed. The result
    is written as JSON of the form ``{"features": [...]}`` to ``output_uri``.

    Args:
        labels_uri: URI of the input GeoJSON document.
        output_uri: URI the filtered document is written to.
        class_names: Collection of class names to keep.
    """
    labels = json.loads(file_to_str(labels_uri))
    filtered_features = []

    for feature in labels['features']:
        properties = feature.get('properties')
        if not properties:
            continue

        # Bug fix: the fallback previously called the dict
        # (`properties('label')`), raising TypeError instead of looking up
        # the key.
        class_name = properties.get('class_name') or properties.get('label')
        if class_name not in class_names:
            continue

        # Copy only features that are kept, rather than every feature.
        kept = copy.deepcopy(feature)
        # Bug fix: `del` raised KeyError for features without a class_id.
        kept['properties'].pop('class_id', None)
        filtered_features.append(kept)

    # NOTE(review): the output has no top-level "type": "FeatureCollection";
    # confirm downstream readers do not require it.
    new_labels = {'features': filtered_features}
    str_to_file(json.dumps(new_labels), output_uri)
示例#8
0
def get_scene_info(csv_uri):
    """Return the rows of the comma-delimited file at ``csv_uri``.

    Each row is the list of string fields produced by ``csv.reader``.
    """
    text = file_to_str(csv_uri)
    return list(csv.reader(StringIO(text), delimiter=','))
示例#9
0
def get_acc(metrics_uri):
    """Return the top-1 ``res5avg`` test accuracy from a metrics log.

    The file at ``metrics_uri`` is read as lines of JSON. The last non-blank
    line is parsed, and the value at
    ``['test_accuracy_list_meter']['top_1']['res5avg']`` is returned.

    Trailing blank lines and a missing final newline are both tolerated.
    The previous ``split('\\n')[-2]`` relied on exactly one trailing
    newline and otherwise parsed the wrong line.

    Raises:
        ValueError: if the file contains no non-blank lines, or the last one
            is not valid JSON.
        KeyError: if the expected keys are missing.
    """
    lines = [line for line in file_to_str(metrics_uri).splitlines()
             if line.strip()]
    if not lines:
        raise ValueError('No metrics found in {}'.format(metrics_uri))

    metrics_dict = json.loads(lines[-1])
    return metrics_dict['test_accuracy_list_meter']['top_1']['res5avg']
示例#10
0
 def load(stats_uri):
     """Build a RasterStats from the JSON document at ``stats_uri``.

     The document must contain ``means`` and ``stds`` keys (otherwise a
     KeyError is raised). Their values are assigned as-is to the
     corresponding attributes of a new RasterStats.
     """
     parsed = json.loads(file_to_str(stats_uri))
     result = RasterStats()
     result.means, result.stds = parsed['means'], parsed['stds']
     return result
示例#11
0
 def print_messages(self):
     """Print the contents of each file listed in ``self.config.message_uris``."""
     for uri in self.config.message_uris:
         print(file_to_str(uri))
示例#12
0
 def print_messages(self):
     """Read every URI in ``self.config.message_uris`` and print its text."""
     uris = self.config.message_uris
     for uri in uris:
         text = file_to_str(uri)
         print(text)