Example #1
    def test_http(self):
        download_dir = '/download_dir'
        uri = 'http://bucket/my/file.txt'
        path = get_local_path(uri, download_dir)
        self.assertEqual(path, '/download_dir/http/bucket/my/file.txt')

        # simulate a zxy tile URI
        uri = 'http://bucket/10/25/53?auth=426753'
        path = get_local_path(uri, download_dir)
        self.assertEqual(path, '/download_dir/http/bucket/10/25/53')
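These tests pin down the contract that get_local_path appears to implement: a plain local path maps to itself, while a remote URI is mirrored under download_dir as scheme/host/path, with any query string dropped. A minimal sketch consistent with the assertions above (get_local_path_sketch is a hypothetical name; the real implementation may differ):

import os
from urllib.parse import urlparse


def get_local_path_sketch(uri, download_dir):
    parsed = urlparse(uri)
    # Plain local paths map to themselves.
    if parsed.scheme == '':
        return uri
    # Remote URIs mirror to <download_dir>/<scheme>/<host>/<path>,
    # discarding any query string (e.g. ?auth=...).
    return os.path.join(download_dir, parsed.scheme, parsed.netloc,
                        parsed.path.lstrip('/'))
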
Example #2
    def unzip_data(self) -> List[str]:
        """Unzip dataset zip files.

        Returns:
            paths to directories, each containing the contents of one zip file
        """
        cfg = self.cfg
        if cfg.data.uri.startswith(('s3://', '/')):
            data_uri = cfg.data.uri
        else:
            data_uri = join(cfg.base_uri, cfg.data.uri)

        data_dirs = []
        zip_uris = [data_uri] if data_uri.endswith('.zip') else list_paths(
            data_uri, 'zip')
        for zip_ind, zip_uri in enumerate(zip_uris):
            zip_path = get_local_path(zip_uri, self.data_cache_dir)
            if not isfile(zip_path):
                zip_path = download_if_needed(zip_uri, self.data_cache_dir)
            with zipfile.ZipFile(zip_path, 'r') as zipf:
                data_dir = join(self.tmp_dir, 'data', str(zip_ind))
                data_dirs.append(data_dir)
                zipf.extractall(data_dir)

        return data_dirs
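The get_local_path / isfile / download_if_needed sequence above is a caching idiom: the URI is first mapped to its deterministic location under data_cache_dir, and the download only happens on a cache miss. The same idiom in isolation (cached_download is a hypothetical name; get_local_path and download_if_needed are the library helpers used above):

from os.path import isfile


def cached_download(uri, cache_dir):
    # Compute where the URI would live in the cache, then fetch it only
    # if it is not already there.
    path = get_local_path(uri, cache_dir)
    if not isfile(path):
        path = download_if_needed(uri, cache_dir)
    return path
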
Example #3
    def unzip_data(self, uri: Union[str, List[str]]) -> List[str]:
        """Unzip dataset zip files.

        Args:
            uri: a list of URIs of zip files or the URI of a directory containing
                zip files

        Returns:
            paths to directories, each containing the contents of one zip file
        """
        cfg = self.cfg
        data_dirs = []

        if isinstance(uri, list):
            zip_uris = uri
        else:
            # TODO generalize this to work with any file system
            if uri.startswith(('s3://', '/')):
                data_uri = uri
            else:
                data_uri = join(cfg.base_uri, uri)
            zip_uris = ([data_uri] if data_uri.endswith('.zip') else
                        list_paths(data_uri, 'zip'))

        for zip_ind, zip_uri in enumerate(zip_uris):
            zip_path = get_local_path(zip_uri, self.data_cache_dir)
            if not isfile(zip_path):
                zip_path = download_if_needed(zip_uri, self.data_cache_dir)
            with zipfile.ZipFile(zip_path, 'r') as zipf:
                data_dir = join(self.tmp_dir, 'data', str(uuid.uuid4()),
                                str(zip_ind))
                data_dirs.append(data_dir)
                zipf.extractall(data_dir)

        return data_dirs
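Compared to Example #2, this version accepts either an explicit list of zip URIs or the URI of a directory of zip files, and extracts into a uuid4-named subdirectory so repeated calls within the same tmp_dir cannot collide. Hypothetical usage, assuming learner is an instance of the containing class:

dirs = learner.unzip_data('s3://bucket/datasets')          # directory of zips
dirs = learner.unzip_data(['s3://bucket/datasets/a.zip',
                           's3://bucket/datasets/b.zip'])  # explicit list
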
Example #4
def crop_image(image_uri, window, crop_uri):
    im_dataset = rasterio.open(image_uri)
    rasterio_window = window.rasterio_format()
    im = im_dataset.read(window=rasterio_window)

    with tempfile.TemporaryDirectory() as tmp_dir:
        crop_path = get_local_path(crop_uri, tmp_dir)
        make_dir(crop_path, use_dirname=True)

        meta = im_dataset.meta
        meta['width'], meta['height'] = window.get_width(), window.get_height()
        meta['transform'] = rasterio.windows.transform(rasterio_window,
                                                       im_dataset.transform)

        with rasterio.open(crop_path, 'w', **meta) as dst:
            dst.colorinterp = im_dataset.colorinterp
            dst.write(im)

        upload_or_copy(crop_path, crop_uri)
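crop_image only assumes a window object exposing rasterio_format(), get_width(), and get_height() (in Raster Vision this is a Box). A hypothetical stand-in that makes the expected contract explicit, with made-up URIs:

class _Window:
    # rasterio_format() yields ((row_start, row_stop), (col_start, col_stop)).
    def __init__(self, ymin, xmin, ymax, xmax):
        self.ymin, self.xmin, self.ymax, self.xmax = ymin, xmin, ymax, xmax

    def rasterio_format(self):
        return ((self.ymin, self.ymax), (self.xmin, self.xmax))

    def get_width(self):
        return self.xmax - self.xmin

    def get_height(self):
        return self.ymax - self.ymin


crop_image('s3://bucket/scene.tif', _Window(0, 0, 512, 512),
           's3://bucket/scene-crop.tif')
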
Example #5
    def __init__(self,
                 cfg: LearnerConfig,
                 tmp_dir: str,
                 model_path: Optional[str] = None):
        """Constructor.

        Args:
            cfg: configuration
            tmp_dir: root of temp dirs
            model_path: a local path to model weights. If provided, the model is loaded
                and it is assumed that this Learner will be used for prediction only.
        """
        self.cfg = cfg
        self.tmp_dir = tmp_dir

        # TODO make cache dirs configurable
        torch_cache_dir = '/opt/data/torch-cache'
        os.environ['TORCH_HOME'] = torch_cache_dir
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.data_cache_dir = '/opt/data/data-cache'
        make_dir(self.data_cache_dir)

        self.model = self.build_model()
        self.model.to(self.device)

        if model_path is not None:
            if isfile(model_path):
                self.model.load_state_dict(
                    torch.load(model_path, map_location=self.device))
            else:
                raise Exception(
                    'Model could not be found at {}'.format(model_path))
            self.model.eval()
        else:
            log.info(self.cfg)

            # ds = dataset, dl = dataloader
            self.train_ds = None
            self.train_dl = None
            self.valid_ds = None
            self.valid_dl = None
            self.test_ds = None
            self.test_dl = None

            if cfg.output_uri.startswith('s3://'):
                self.output_dir = get_local_path(cfg.output_uri, tmp_dir)
                make_dir(self.output_dir, force_empty=True)
                if not cfg.overfit_mode:
                    self.sync_from_cloud()
            else:
                self.output_dir = cfg.output_uri
                make_dir(self.output_dir)

            self.last_model_path = join(self.output_dir, 'last-model.pth')
            self.config_path = join(self.output_dir, 'config.json')
            self.train_state_path = join(self.output_dir, 'train-state.json')
            self.log_path = join(self.output_dir, 'log.csv')
            model_bundle_fn = basename(cfg.get_model_bundle_uri())
            self.model_bundle_path = join(self.output_dir, model_bundle_fn)
            self.metric_names = self.build_metric_names()

            json_to_file(self.cfg.dict(), self.config_path)
            self.load_init_weights()
            self.load_checkpoint()
            self.opt = self.build_optimizer()
            self.setup_data()
            self.start_epoch = self.get_start_epoch()
            self.steps_per_epoch = len(
                self.train_ds) // self.cfg.solver.batch_sz
            self.step_scheduler = self.build_step_scheduler()
            self.epoch_scheduler = self.build_epoch_scheduler()
            self.setup_tensorboard()
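When output_uri is on S3, the constructor works against a local mirror of the output directory and pulls down any state from a previous run via sync_from_cloud. A hedged sketch of such a helper as a method of this class, assuming the aws CLI is available (the project's actual implementation may differ):

import subprocess


def sync_from_cloud(self):
    # Mirror the remote output directory into the local one so that
    # checkpoints and logs from a previous run can be resumed.
    subprocess.run(
        ['aws', 's3', 'sync', self.cfg.output_uri, self.output_dir],
        check=True)
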
Example #6
    def setup_data(self):
        cfg = self.cfg
        batch_sz = cfg.solver.batch_sz
        num_workers = cfg.data.num_workers

        # download and unzip data
        if cfg.data.uri.startswith(('s3://', '/')):
            data_uri = cfg.data.uri
        else:
            data_uri = join(cfg.base_uri, cfg.data.uri)

        data_dirs = []
        zip_uris = [data_uri] if data_uri.endswith('.zip') else list_paths(
            data_uri, 'zip')
        for zip_ind, zip_uri in enumerate(zip_uris):
            zip_path = get_local_path(zip_uri, self.data_cache_dir)
            if not isfile(zip_path):
                zip_path = download_if_needed(zip_uri, self.data_cache_dir)
            with zipfile.ZipFile(zip_path, 'r') as zipf:
                data_dir = join(self.tmp_dir, 'data', str(zip_ind))
                data_dirs.append(data_dir)
                zipf.extractall(data_dir)

        # build datasets -- one per zip file and then merge them into a single dataset
        train_ds = []
        valid_ds = []
        test_ds = []
        # The transforms don't depend on the zip contents, so build them once.
        transform = Compose(
            [Resize((cfg.data.img_sz, cfg.data.img_sz)),
             ToTensor()])
        aug_transform = Compose([
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            ColorJitter(0.1, 0.1, 0.1, 0.1),
            Resize((cfg.data.img_sz, cfg.data.img_sz)),
            ToTensor()
        ])

        for data_dir in data_dirs:
            train_dir = join(data_dir, 'train')
            valid_dir = join(data_dir, 'valid')

            if isdir(train_dir):
                if cfg.overfit_mode:
                    train_ds.append(
                        ImageRegressionDataset(train_dir,
                                               cfg.data.class_names,
                                               transform=transform))
                else:
                    train_ds.append(
                        ImageRegressionDataset(train_dir,
                                               cfg.data.class_names,
                                               transform=aug_transform))

            if isdir(valid_dir):
                valid_ds.append(
                    ImageRegressionDataset(valid_dir,
                                           cfg.data.class_names,
                                           transform=transform))
                test_ds.append(
                    ImageRegressionDataset(valid_dir,
                                           cfg.data.class_names,
                                           transform=transform))

        train_ds, valid_ds, test_ds = \
            ConcatDataset(train_ds), ConcatDataset(valid_ds), ConcatDataset(test_ds)

        if cfg.overfit_mode:
            train_ds = Subset(train_ds, range(batch_sz))
            valid_ds = train_ds
            test_ds = train_ds
        elif cfg.test_mode:
            train_ds = Subset(train_ds, range(batch_sz))
            valid_ds = Subset(valid_ds, range(batch_sz))
            test_ds = Subset(test_ds, range(batch_sz))

        train_dl = DataLoader(train_ds,
                              shuffle=True,
                              batch_size=batch_sz,
                              num_workers=num_workers,
                              pin_memory=True)
        valid_dl = DataLoader(valid_ds,
                              shuffle=True,
                              batch_size=batch_sz,
                              num_workers=num_workers,
                              pin_memory=True)
        test_dl = DataLoader(test_ds,
                             shuffle=True,
                             batch_size=batch_sz,
                             num_workers=num_workers,
                             pin_memory=True)

        self.train_ds, self.valid_ds, self.test_ds = (train_ds, valid_ds,
                                                      test_ds)
        self.train_dl, self.valid_dl, self.test_dl = (train_dl, valid_dl,
                                                      test_dl)
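In overfit_mode the three splits above alias a single batch of training data, which is a quick way to verify that the model can drive training loss toward zero; test_mode merely truncates each split to one batch so a full pipeline run stays fast. A hypothetical smoke test of the resulting loaders (learner is an instance of this class and the images are assumed to be RGB):

x, y = next(iter(learner.train_dl))
assert x.shape == (learner.cfg.solver.batch_sz, 3,
                   learner.cfg.data.img_sz, learner.cfg.data.img_sz)
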
Example #7
    def test_s3(self):
        download_dir = '/download_dir'
        uri = 's3://bucket/my/file.txt'
        path = get_local_path(uri, download_dir)
        self.assertEqual(path, '/download_dir/s3/bucket/my/file.txt')
Example #8
    def test_local(self):
        download_dir = '/download_dir'
        uri = '/my/file.txt'
        path = get_local_path(uri, download_dir)
        self.assertEqual(path, uri)
Example #9
    def save(self, labels):
        """Save.

        Args:
            labels: (SemanticSegmentationLabels) labels to be saved
        """
        local_path = get_local_path(self.uri, self.tmp_dir)
        make_dir(local_path, use_dirname=True)

        transform = self.crs_transformer.get_affine_transform()
        crs = self.crs_transformer.get_image_crs()

        band_count = 1
        dtype = np.uint8
        if self.class_trans:
            band_count = 3

        mask = (np.zeros((self.extent.ymax, self.extent.xmax), dtype=np.uint8)
                if self.vector_output else None)

        # https://github.com/mapbox/rasterio/blob/master/docs/quickstart.rst
        # https://rasterio.readthedocs.io/en/latest/topics/windowed-rw.html
        with rasterio.open(local_path,
                           'w',
                           driver='GTiff',
                           height=self.extent.ymax,
                           width=self.extent.xmax,
                           count=band_count,
                           dtype=dtype,
                           transform=transform,
                           crs=crs) as dataset:
            for window in labels.get_windows():
                label_arr = labels.get_label_arr(window)
                window = window.intersection(self.extent)
                label_arr = label_arr[0:window.get_height(),
                                      0:window.get_width()]

                if mask is not None:
                    mask[window.ymin:window.ymax,
                         window.xmin:window.xmax] = label_arr

                window = window.rasterio_format()
                if self.class_trans:
                    rgb_labels = self.class_trans.class_to_rgb(label_arr)
                    for chan in range(3):
                        dataset.write_band(chan + 1,
                                           rgb_labels[:, :, chan],
                                           window=window)
                else:
                    img = label_arr.astype(dtype)
                    dataset.write_band(1, img, window=window)

        upload_or_copy(local_path, self.uri)

        if self.vector_output:
            import mask_to_polygons.vectorification as vectorification
            import mask_to_polygons.processing.denoise as denoise

            for vo in self.vector_output:
                denoise_radius = vo.denoise
                uri = vo.uri
                mode = vo.get_mode()
                class_id = vo.class_id
                class_mask = np.array(mask == class_id, dtype=np.uint8)
                local_geojson_path = get_local_path(uri, self.tmp_dir)

                def transform(x, y):
                    return self.crs_transformer.pixel_to_map((x, y))

                if denoise_radius > 0:
                    class_mask = denoise.denoise(class_mask, denoise_radius)

                geojson = None
                if uri and mode == 'buildings':
                    geojson = vectorification.geojson_from_mask(
                        mask=class_mask,
                        transform=transform,
                        mode=mode,
                        min_aspect_ratio=vo.min_aspect_ratio,
                        min_area=vo.min_area,
                        width_factor=vo.element_width_factor,
                        thickness=vo.element_thickness)
                elif uri and mode == 'polygons':
                    geojson = vectorification.geojson_from_mask(
                        mask=class_mask, transform=transform, mode=mode)

                if local_geojson_path and geojson is not None:
                    with open(local_geojson_path, 'w') as file_out:
                        file_out.write(geojson)
                    # Upload after closing the file so its contents are
                    # fully flushed to disk.
                    upload_or_copy(local_geojson_path, uri)
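The vector output path above feeds a binary per-class mask through mask_to_polygons. A hypothetical smoke test of that conversion, using an identity pixel-to-map transform and the same keyword arguments as the call above:

import numpy as np
import mask_to_polygons.vectorification as vectorification

mask = np.zeros((100, 100), dtype=np.uint8)
mask[10:50, 10:50] = 1
geojson = vectorification.geojson_from_mask(
    mask=mask, transform=lambda x, y: (x, y), mode='polygons')
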
Example #10
def _zxy2geotiff(tile_schema, zoom, bounds, output_uri, make_cog=False):
    """Generates a GeoTIFF of a bounded region from a ZXY tile server.

    Args:
        tile_schema: (str) the URI schema for zxy tiles (i.e. a slippy map tile server)
            of the form /tileserver-uri/{z}/{x}/{y}.png. If {-y} is used, the tiles
            are assumed to be indexed using TMS coordinates, where the y axis starts
            at the southernmost point. The URI can be for http, S3, or the local
            file system.
        zoom: (int) the zoom level to use when retrieving tiles
        bounds: (list) a list of length 4 containing min_lat, min_lng,
            max_lat, max_lng
        output_uri: (str) where to save the GeoTIFF. The URI can be for http, S3, or the
            local file system
    """
    min_lat, min_lng, max_lat, max_lng = bounds
    if min_lat >= max_lat:
        raise ValueError('min_lat must be < max_lat')
    if min_lng >= max_lng:
        raise ValueError('min_lng must be < max_lng')

    is_tms = False
    if '{-y}' in tile_schema:
        tile_schema = tile_schema.replace('{-y}', '{y}')
        is_tms = True

    tmp_dir_obj = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_obj.name

    # Get range of tiles that cover bounds.
    output_path = get_local_path(output_uri, tmp_dir)
    tile_sz = 256
    t = mercantile.tile(min_lng, max_lat, zoom)
    xmin, ymin = t.x, t.y
    t = mercantile.tile(max_lng, min_lat, zoom)
    xmax, ymax = t.x, t.y

    # The supplied bounds are contained within the "tile bounds" -- i.e. the
    # bounds of the set of tiles that covers the supplied bounds. Therefore,
    # we need to crop out the imagery that lies within the supplied bounds.
    # We do this by computing a top, bottom, left, and right offset in pixel
    # units of the supplied bounds against the tile bounds. Getting the offsets
    # in pixel units involves converting lng/lat to web mercator units since we
    # assume that is the CRS of the tiles. These offsets are then used to crop
    # individual tiles and place them correctly into the output raster.
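    # Worked example: a point lying 25% of the way across a 256-pixel tile
    # has a pixel offset of 0.25 * 256 = 64; merc2pixel is assumed to compute
    # this fraction from the tile's mercator bounds.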
    nw_merc_x, nw_merc_y = lnglat2merc(min_lng, max_lat)
    left_pix_offset, top_pix_offset = merc2pixel(xmin, ymin, zoom, nw_merc_x,
                                                 nw_merc_y)

    se_merc_x, se_merc_y = lnglat2merc(max_lng, min_lat)
    se_left_pix_offset, se_top_pix_offset = merc2pixel(xmax, ymax, zoom,
                                                       se_merc_x, se_merc_y)
    right_pix_offset = tile_sz - se_left_pix_offset
    bottom_pix_offset = tile_sz - se_top_pix_offset

    uncropped_height = tile_sz * (ymax - ymin + 1)
    uncropped_width = tile_sz * (xmax - xmin + 1)
    height = uncropped_height - top_pix_offset - bottom_pix_offset
    width = uncropped_width - left_pix_offset - right_pix_offset

    transform = rasterio.transform.from_bounds(nw_merc_x, se_merc_y, se_merc_x,
                                               nw_merc_y, width, height)
    with rasterio.open(
            output_path,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=3,
            crs='epsg:3857',
            transform=transform,
            dtype=rasterio.uint8) as dataset:
        out_x = 0
        for xi, x in enumerate(range(xmin, xmax + 1)):
            tile_xmin, tile_xmax = 0, tile_sz - 1
            if x == xmin:
                tile_xmin += left_pix_offset
            if x == xmax:
                tile_xmax -= right_pix_offset
            window_width = tile_xmax - tile_xmin + 1

            out_y = 0
            for yi, y in enumerate(range(ymin, ymax + 1)):
                tile_ymin, tile_ymax = 0, tile_sz - 1
                if y == ymin:
                    tile_ymin += top_pix_offset
                if y == ymax:
                    tile_ymax -= bottom_pix_offset
                window_height = tile_ymax - tile_ymin + 1

                # Convert from xyz to tms if needed.
                # https://gist.github.com/tmcw/4954720
                if is_tms:
                    y = (2**zoom) - y - 1
                tile_uri = tile_schema.format(x=x, y=y, z=zoom)
                tile_path = download_if_needed(tile_uri, tmp_dir)
                img = np.array(Image.open(tile_path))
                img = img[tile_ymin:tile_ymax + 1, tile_xmin:tile_xmax + 1, :]

                window = Window(out_x, out_y, window_width, window_height)
                dataset.write(
                    np.transpose(img[:, :, 0:3], (2, 0, 1)), window=window)
                out_y += window_height
            out_x += window_width

    if make_cog:
        create_cog(output_path, output_uri, tmp_dir)
    else:
        upload_or_copy(output_path, output_uri)
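
The function above assumes two small helpers, lnglat2merc and merc2pixel. Hedged sketches based on standard web-mercator math (the project's actual implementations may differ):

import math

import mercantile


def lnglat2merc(lng, lat):
    # Forward EPSG:3857 projection on a sphere of radius 6378137 m.
    x = 6378137 * math.radians(lng)
    y = 6378137 * math.log(math.tan((90 + lat) * math.pi / 360))
    return x, y


def merc2pixel(tile_x, tile_y, zoom, merc_x, merc_y, tile_sz=256):
    # Pixel offset of a mercator point within tile (tile_x, tile_y).
    b = mercantile.xy_bounds(tile_x, tile_y, zoom)
    pix_x = round(tile_sz * (merc_x - b.left) / (b.right - b.left))
    pix_y = round(tile_sz * (b.top - merc_y) / (b.top - b.bottom))
    return pix_x, pix_y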