예제 #1
0
        def file_pyramid(pyramid_factory):
            """Create a file-based collection factory and load the requested data with it.

            Relies on variables of the enclosing scope (``layer_source_info``, ``metadata``,
            ``single_level``, ``geometries``, the projected polygons, dates, ...).

            :param pyramid_factory: callable invoked as
                ``pyramid_factory(opensearch_endpoint, collection_id, link_titles, root_path)``.
            :return: ``factory.datacube_seq(...)`` when ``single_level`` is set; otherwise
                ``factory.pyramid_seq(...)`` over either the projected polygons (when
                ``geometries`` is truthy) or the ``extent``/``srs`` bounding box.
            """
            endpoint = layer_source_info.get(
                'opensearch_endpoint', ConfigParams().default_opensearch_endpoint)
            collection_id = layer_source_info['opensearch_collection_id']
            link_titles = metadata.opensearch_link_titles
            root = layer_source_info['root_path']

            factory = pyramid_factory(endpoint, collection_id, link_titles, root)

            if single_level:
                # TODO EP-3561 UTM is not always the native projection of a layer (PROBA-V), need to determine optimal projection
                return factory.datacube_seq(
                    projected_polygons_native_crs, from_date, to_date,
                    metadata_properties(), correlation_id, datacubeParams)

            if not geometries:
                return factory.pyramid_seq(
                    extent, srs, from_date, to_date,
                    metadata_properties(), correlation_id)

            return factory.pyramid_seq(
                projected_polygons.polygons(), projected_polygons.crs(),
                from_date, to_date, metadata_properties(), correlation_id)
예제 #2
0
    def __init__(self):
        """Start (or reuse) a local Spark context and build small tile fixtures.

        Sets ``pysc`` (the SparkContext), two single-band 4x4 float arrays filled with
        1 and 2, a 4x4 extent, a one-tile layout and a fixed timestamp ``now``.
        """
        spark_conf = geopyspark_conf(master="local[*]", appName="test")
        spark_conf.set('spark.kryoserializer.buffer.max', value='1G')
        spark_conf.set('spark.ui.enabled', True)

        # In a CI context, pin driver and executor memory to 2G.
        if ConfigParams().is_ci_context:
            for memory_setting in ('spark.driver.memory', 'spark.executor.memory'):
                spark_conf.set(key=memory_setting, value='2G')

        self.pysc = SparkContext.getOrCreate(spark_conf)

        # Shape (bands, rows, cols) = (1, 4, 4), float64 like np.zeros.
        self.first = np.full((1, 4, 4), 1.0)
        self.second = np.full((1, 4, 4), 2.0)

        self.extent = dict(xmin=0.0, ymin=0.0, xmax=4.0, ymax=4.0)
        self.layout = dict(layoutCols=1, layoutRows=1, tileCols=4, tileRows=4)

        self.now = datetime.datetime.strptime("2017-09-25T11:37:00Z", '%Y-%m-%dT%H:%M:%SZ')
예제 #3
0
def _schedule_task(task_id: str, arguments: dict):
    """Publish a task message on the ``openeo-async-tasks`` Kafka topic.

    The message value is the UTF-8 encoded JSON of ``{"task_id": ..., "arguments": ...}``.
    When ``ConfigParams().async_task_handler_environment`` is set, it is attached as an
    ``env`` header. The send is awaited with a 120 second timeout and the producer is
    closed afterwards, whether or not sending succeeded.
    """
    task = dict(task_id=task_id, arguments=arguments)
    env = ConfigParams().async_task_handler_environment

    producer = KafkaProducer(
        bootstrap_servers="epod-master1.vgt.vito.be:6668,epod-master2.vgt.vito.be:6668,epod-master3.vgt.vito.be:6668",
        security_protocol='PLAINTEXT',
        acks='all'
    )

    try:
        task_message = json.dumps(task)
        payload = task_message.encode('utf-8')
        headers = [('env', env.encode('utf-8'))] if env else None
        pending = producer.send(topic="openeo-async-tasks", value=payload, headers=headers)
        pending.get(timeout=120)

        _log.info(f"scheduled task {task_message} on env {env}")
    finally:
        producer.close()
 def __init__(self):
     """Configure the ZooKeeper root/hosts and ensure the root znode path exists."""
     self._root = '/openeo/services'
     zookeeper_nodes = ConfigParams().zookeepernodes
     self._hosts = ','.join(zookeeper_nodes)

     with self._zk_client() as client:
         client.ensure_path(self._root)

     # Services registered by the current process, kept in memory next to ZooKeeper.
     self._services = dict()
예제 #5
0
def main():
    """Round-trip a UDP through the ZooKeeper repository: save, list, delete, fetch again."""
    repo = ZooKeeperUserDefinedProcessRepository(hosts=ConfigParams().zookeepernodes)

    user_id, process_graph_id = 'vdboschj', 'evi'
    load_collection_node = {
        'process_id': 'load_collection',
        'arguments': {'id': 'PROBAV_L3_S10_TOC_NDVI_333M'},
    }
    udp_spec = {
        'id': process_graph_id,
        'process_graph': {'loadcollection1': load_collection_node},
    }

    repo.save(user_id=user_id, process_id=process_graph_id, spec=udp_spec)

    for stored_udp in repo.get_for_user(user_id):
        print(stored_udp)

    repo.delete(user_id, process_graph_id)
    print(repo.get(user_id, process_graph_id))
예제 #6
0
def main():
    """Clean up batch jobs and secondary services older than 60 days.

    Launches a Py4J Java gateway (jar and classpath configurable on the command line)
    because removing batch jobs needs access to the JVM.
    """
    _log.info(f"ConfigParams(): {ConfigParams()}")

    parser = argparse.ArgumentParser(
        usage="OpenEO Cleaner",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli_options = (
        ("--py4j-jarpath", "venv/share/py4j/py4j0.10.7.jar", 'Path to the Py4J jar'),
        ("--py4j-classpath", "geotrellis-extensions-2.2.0-SNAPSHOT.jar",
         'Classpath used to launch the Java Gateway'),
    )
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, default=default_value, help=help_text)

    args = parser.parse_args()

    gateway = JavaGateway.launch_gateway(
        jarpath=args.py4j_jarpath,
        classpath=args.py4j_classpath,
        javaopts=[
            "-client",
            "-Dsoftware.amazon.awssdk.http.service.impl=software.amazon.awssdk.http.urlconnection.UrlConnectionSdkHttpService",
        ],
        die_on_exit=True)

    cutoff = datetime.today() - timedelta(days=60)

    remove_batch_jobs_before(cutoff, gateway.jvm)
    remove_secondary_services_before(cutoff)
예제 #7
0
        def file_probav_pyramid():
            """Load a PROBA-V pyramid through the JVM ``ProbaVPyramidFactory``.

            Uses the enclosing scope's ``layer_source_info``, cell size, extent/srs,
            date range, ``band_indices`` and ``correlation_id``.
            """
            endpoint = layer_source_info.get(
                'opensearch_endpoint', ConfigParams().default_opensearch_endpoint)
            factory_class = jvm.org.openeo.geotrellis.file.ProbaVPyramidFactory
            cell_size = jvm.geotrellis.raster.CellSize(cell_width, cell_height)

            factory = factory_class(
                endpoint,
                layer_source_info.get('opensearch_collection_id'),
                layer_source_info.get('root_path'),
                cell_size,
            )
            return factory.pyramid_seq(extent, srs, from_date, to_date, band_indices, correlation_id)
예제 #8
0
    def setup_batch_jobs():
        """Outside CI: ensure the job registry paths exist and start a daemon status tracker."""
        if ConfigParams().is_ci_context:
            return

        with JobRegistry() as registry:
            registry.ensure_paths()

        tracker = JobTracker(JobRegistry, principal="", keytab="")
        tracker_thread = threading.Thread(target=tracker.loop_update_statuses, daemon=True)
        tracker_thread.start()
예제 #9
0
def zk_client(hosts: str = None):
    """Yield a started Kazoo ZooKeeper client; stop and close it afterwards.

    Written as a generator. NOTE(review): presumably meant to be wrapped with
    ``contextlib.contextmanager`` (no decorator is visible here) — confirm at the
    definition site.

    :param hosts: comma-separated ZooKeeper hosts. When omitted, the configured
        ``ConfigParams().zookeepernodes`` are used, looked up at call time.
    """
    if hosts is None:
        # Resolve lazily: a default expression in the signature would construct
        # ConfigParams() once, when the module is imported.
        hosts = ','.join(ConfigParams().zookeepernodes)

    zk = KazooClient(hosts)
    try:
        # start() is inside the try so a failed connection attempt still gets
        # stop()/close() and releases the client's resources.
        zk.start()
        yield zk
    finally:
        zk.stop()
        zk.close()
예제 #10
0
def _create_job_dir(job_dir: Path):
    """Create the batch job directory and set up its group ownership/permissions.

    Outside Kubernetes deployments (``ConfigParams().is_kube_deploy`` false) the
    directory's group is changed to ``eodata``.
    """
    # Fixed: the message previously ended in a doubled ")" ("...(parent dir: {p}))").
    logger.info("creating job dir {j!r} (parent dir: {p})".format(
        j=job_dir, p=describe_path(job_dir.parent)))
    ensure_dir(job_dir)
    if not ConfigParams().is_kube_deploy:
        shutil.chown(job_dir, user=None, group='eodata')

    # S_ISGID on a directory makes children inherit this group; S_IWGRP is group write.
    # NOTE(review): presumably _add_permissions ORs these bits into the current mode — confirm.
    _add_permissions(job_dir, stat.S_ISGID | stat.S_IWGRP)
예제 #11
0
    def __init__(self):
        """Choose a service registry and wire it into secondary services and the catalog.

        CI contexts get an in-memory registry; everything else a ZooKeeper-backed one.
        """
        # TODO: do this with a config instead of hardcoding rules?
        if ConfigParams().is_ci_context:
            self._service_registry = InMemoryServiceRegistry()
        else:
            self._service_registry = ZooKeeperServiceRegistry()

        registry = self._service_registry
        super().__init__(
            secondary_services=GpsSecondaryServices(service_registry=registry),
            catalog=get_layer_catalog(service_registry=registry),
            batch_jobs=GpsBatchJobs(),
        )
예제 #12
0
        def accumulo_pyramid():
            pyramidFactory = jvm.org.openeo.geotrellisaccumulo.PyramidFactory(
                "hdp-accumulo-instance",
                ','.join(ConfigParams().zookeepernodes))
            if layer_source_info.get("split", False):
                pyramidFactory.setSplitRanges(True)

            accumulo_layer_name = layer_source_info['data_id']
            nonlocal still_needs_band_filter
            still_needs_band_filter = bool(band_indices)
            return pyramidFactory.pyramid_seq(accumulo_layer_name, extent, srs,
                                              from_date, to_date)
예제 #13
0
def update_zookeeper(host: str, port: int, env: str) -> None:
    """Register ``host:port`` as server "0" of the ``openeo-<env>`` Traefik cluster.

    Opens a dedicated ZooKeeper connection for the update and always stops and closes it.
    """
    from kazoo.client import KazooClient
    from openeogeotrellis.configparams import ConfigParams

    cluster_id = 'openeo-' + env
    client = KazooClient(hosts=','.join(ConfigParams().zookeepernodes))
    client.start()

    try:
        traefik = Traefik(client)
        traefik.add_load_balanced_server(
            cluster_id=cluster_id, server_id="0", host=host, port=port, environment=env)
    finally:
        client.stop()
        client.close()
예제 #14
0
        def accumulo_pyramid():
            pyramidFactory = jvm.org.openeo.geotrellisaccumulo.PyramidFactory(
                "hdp-accumulo-instance",
                ','.join(ConfigParams().zookeepernodes))
            if layer_source_info.get("split", False):
                pyramidFactory.setSplitRanges(True)

            accumulo_layer_name = layer_source_info['data_id']
            nonlocal still_needs_band_filter
            still_needs_band_filter = bool(band_indices)

            polygons = load_params.aggregate_spatial_geometries

            if polygons:
                projected_polygons = to_projected_polygons(jvm, polygons)
                return pyramidFactory.pyramid_seq(
                    accumulo_layer_name, projected_polygons.polygons(),
                    projected_polygons.crs(), from_date, to_date)
            else:
                return pyramidFactory.pyramid_seq(accumulo_layer_name, extent,
                                                  srs, from_date, to_date)
예제 #15
0
def get_layer_catalog(
    service_registry: AbstractServiceRegistry = None
) -> GeoPySparkLayerCatalog:
    """
    Get layer catalog (from JSON files)

    The first file in ``ConfigParams().layer_catalog_metadata_files`` is the base. Every
    further file is merged into it recursively per layer ``id`` (``overwrite=True``).
    Without a ``service_registry`` an ``InMemoryServiceRegistry`` is used.
    """
    catalog_files = ConfigParams().layer_catalog_metadata_files
    base_file = catalog_files[0]
    logger.info("Reading layer catalog metadata from {f!r}".format(f=base_file))
    metadata = read_json(base_file)

    if len(catalog_files) > 1:
        # Index layers by id so each override file can be merged per layer.
        layers_by_id = {layer["id"]: layer for layer in metadata}
        for path in catalog_files[1:]:
            logger.info("Updating layer catalog metadata from {f!r}".format(f=path))
            overrides = {layer["id"]: layer for layer in read_json(path)}
            layers_by_id = dict_merge_recursive(layers_by_id, overrides, overwrite=True)
        metadata = list(layers_by_id.values())

    registry = service_registry or InMemoryServiceRegistry()
    return GeoPySparkLayerCatalog(all_metadata=metadata, service_registry=registry)
예제 #16
0
    def update_statuses(self) -> None:
        """Poll every running batch job and record its status in the job registry.

        For each running job that has an ``application_id``:

        * Kubernetes (``ConfigParams().is_kube_deploy``): status and start/finish times are
          patched in. A "COMPLETED" job has its ``batch_jobs/<job_id>`` results downloaded
          from S3, its result metadata patched in, and is marked done.
        * YARN otherwise: status, times and resource usage are patched in. Once the final
          state is not "UNDEFINED", result metadata is patched in. Finished jobs also have
          their dependencies removed, with a cleanup of dependency sources scheduled if
          there are any. The job is then marked done.

        A job whose application id is unknown is marked done. Any other failure while
        handling a job is logged, that job is set to "error" and marked done, and the loop
        continues with the remaining jobs.
        """
        with self._job_registry() as registry:
            registry.ensure_paths()

            jobs_to_track = registry.get_running_jobs()

            for job_info in jobs_to_track:
                # Reset the ids on every iteration. If reading them below fails, the error
                # handler must not act on the previous job's ids (or hit a NameError on the
                # first job and abort the whole loop).
                job_id = user_id = None
                try:
                    job_id, user_id = job_info['job_id'], job_info['user_id']
                    application_id, current_status = job_info[
                        'application_id'], job_info['status']

                    if application_id:
                        try:
                            if ConfigParams().is_kube_deploy:
                                from openeogeotrellis.utils import download_s3_dir
                                state, start_time, finish_time = JobTracker._kube_status(
                                    job_id, user_id)

                                new_status = JobTracker._kube_status_parser(
                                    state)

                                registry.patch(job_id,
                                               user_id,
                                               status=new_status,
                                               started=start_time,
                                               finished=finish_time)

                                if current_status != new_status:
                                    _log.info(
                                        "changed job %s status from %s to %s" %
                                        (job_id, current_status, new_status),
                                        extra={'job_id': job_id})

                                if state == "COMPLETED":
                                    # TODO: do we support SHub batch processes in this environment? The AWS
                                    #  credentials conflict.
                                    download_s3_dir(
                                        "OpenEO-data",
                                        "batch_jobs/{j}".format(j=job_id))

                                    result_metadata = self._batch_jobs.get_results_metadata(
                                        job_id, user_id)
                                    registry.patch(job_id, user_id,
                                                   **result_metadata)

                                    registry.mark_done(job_id, user_id)
                                    _log.info("marked %s as done" % job_id,
                                              extra={'job_id': job_id})
                            else:
                                state, final_state, start_time, finish_time, aggregate_resource_allocation =\
                                    JobTracker._yarn_status(application_id)

                                memory_time_megabyte_seconds, cpu_time_seconds =\
                                    JobTracker._parse_resource_allocation(aggregate_resource_allocation)

                                new_status = JobTracker._to_openeo_status(
                                    state, final_state)

                                registry.patch(
                                    job_id,
                                    user_id,
                                    status=new_status,
                                    started=JobTracker.
                                    _to_serializable_datetime(start_time),
                                    finished=JobTracker.
                                    _to_serializable_datetime(finish_time),
                                    memory_time_megabyte_seconds=
                                    memory_time_megabyte_seconds,
                                    cpu_time_seconds=cpu_time_seconds)

                                if current_status != new_status:
                                    _log.info(
                                        "changed job %s status from %s to %s" %
                                        (job_id, current_status, new_status),
                                        extra={'job_id': job_id})

                                if final_state != "UNDEFINED":
                                    result_metadata = self._batch_jobs.get_results_metadata(
                                        job_id, user_id)
                                    # TODO: skip patching the job znode and read from this file directly?
                                    registry.patch(job_id, user_id,
                                                   **result_metadata)

                                    if new_status == 'finished':
                                        registry.remove_dependencies(
                                            job_id, user_id)

                                        dependency_sources = JobRegistry.get_dependency_sources(
                                            job_info)

                                        if dependency_sources:
                                            async_task.schedule_delete_batch_process_dependency_sources(
                                                job_id, dependency_sources)

                                    registry.mark_done(job_id, user_id)

                                    _log.info("marked %s as done" % job_id,
                                              extra={
                                                  'job_id':
                                                  job_id,
                                                  'area':
                                                  result_metadata.get('area'),
                                                  'unique_process_ids':
                                                  result_metadata.get(
                                                      'unique_process_ids'),
                                                  'cpu_time_seconds':
                                                  cpu_time_seconds
                                              })
                        except JobTracker._UnknownApplicationIdException:
                            registry.mark_done(job_id, user_id)
                except Exception:
                    _log.warning(
                        "resuming with remaining jobs after failing to handle batch job {j}:\n{e}"
                        .format(j=job_id, e=traceback.format_exc()),
                        extra={'job_id': job_id})
                    # Only touch the registry when this job's ids were actually read; otherwise
                    # there is no job we can safely mark as failed.
                    if job_id is not None and user_id is not None:
                        registry.set_status(job_id, user_id, 'error')
                        registry.mark_done(job_id, user_id)
예제 #17
0
if __name__ == '__main__':
    # Script entry point: configure logging, log the active configuration and parse the
    # Kerberos login options (--principal, --keytab) for the job tracker.
    import argparse

    logging.basicConfig(level=logging.INFO)
    openeogeotrellis.backend.logger.setLevel(logging.DEBUG)

    # Emit JSON-formatted records on stdout. basicConfig() above may already have installed
    # a default handler on the root logger, in which case records go to both handlers.
    handler = logging.StreamHandler(stream=sys.stdout)
    handler.formatter = JsonFormatter(
        "%(asctime)s %(name)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z")

    root_logger = logging.getLogger()
    root_logger.addHandler(handler)

    _log.info("ConfigParams(): {c}".format(c=ConfigParams()))

    parser = argparse.ArgumentParser(
        usage="OpenEO JobTracker",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--principal",
                        # NOTE(review): this default looks redacted/obfuscated — confirm the real principal.
                        default="*****@*****.**",
                        help="Principal to be used to login to KDC")
    parser.add_argument(
        "--keytab",
        default="openeo-deploy/mep/openeo.keytab",
        help=
        "The full path to the file that contains the keytab for the principal")
    args = parser.parse_args()

    # NOTE(review): this example ends here; the body of the following `try:` is missing.
    try:
예제 #18
0
    def creodias(
            self,
            projected_polygons,
            from_date: str,
            to_date: str,
            collection_id: str = "Sentinel1",
            correlation_id: str = "NA",
            sar_backscatter_arguments: SarBackscatterArgs = SarBackscatterArgs(
            ),
            bands=None,
            zoom=0,  # TODO: what to do with zoom? It is not used at the moment.
            result_dtype="float32",
            extra_properties=None) -> Dict[int, geopyspark.TiledRasterLayer]:
        """
        Implementation of S1 backscatter calculation with Orfeo in Creodias environment

        Features are grouped per Creodias product path. Each product is processed as a
        whole by the Orfeo pipeline (one run per band), and the result is cut into
        ``tile_size`` x ``tile_size`` tiles of a SPACETIME ``TiledRasterLayer``.

        :param projected_polygons: passed on to ``_build_feature_rdd``.
        :param from_date: start of the date range, passed on to ``_build_feature_rdd``.
        :param to_date: end of the date range, passed on to ``_build_feature_rdd``.
        :param collection_id: passed on to ``_build_feature_rdd``.
        :param correlation_id: passed on to ``_build_feature_rdd``.
        :param sar_backscatter_arguments: backscatter settings. ``options`` may contain
            ``tile_size``, ``otb_memory``, ``elev_geoid``, ``elev_default``, ``debug`` and
            ``to_db``.
        :param bands: band names to compute, in output order; defaults to ``["VH", "VV"]``.
        :param zoom: passed on to ``_build_feature_rdd``; also the key of the returned dict.
        :param result_dtype: numpy dtype of the output tiles, also set as the layer cell type.
        :param extra_properties: passed on to ``_build_feature_rdd``; ``None`` means none.
        :return: ``{zoom: tile_layer}``
        :raises FeatureUnsupportedException: when the mask, contributing_area,
            local_incidence_angle or ellipsoid_incidence_angle band is requested.
        """
        if extra_properties is None:
            # Fresh dict per call; a ``{}`` default would be one dict shared by all calls.
            extra_properties = {}

        # Initial argument checking
        bands = bands or ["VH", "VV"]
        sar_calibration_lut = self._get_sar_calibration_lut(
            sar_backscatter_arguments.coefficient)
        if sar_backscatter_arguments.mask:
            raise FeatureUnsupportedException(
                "sar_backscatter: mask band is not supported")
        if sar_backscatter_arguments.contributing_area:
            raise FeatureUnsupportedException(
                "sar_backscatter: contributing_area band is not supported")
        if sar_backscatter_arguments.local_incidence_angle:
            raise FeatureUnsupportedException(
                "sar_backscatter: local_incidence_angle band is not supported")
        if sar_backscatter_arguments.ellipsoid_incidence_angle:
            raise FeatureUnsupportedException(
                "sar_backscatter: ellipsoid_incidence_angle band is not supported"
            )

        # Tile size to use in the TiledRasterLayer.
        tile_size = sar_backscatter_arguments.options.get(
            "tile_size", self._DEFAULT_TILE_SIZE)
        # Passed to the Orfeo pipeline as ``orfeo_memory`` (unit presumably MB — TODO confirm).
        orfeo_memory = sar_backscatter_arguments.options.get("otb_memory", 256)

        # Geoid for orthorectification: get from options, fallback on config.
        elev_geoid = sar_backscatter_arguments.options.get(
            "elev_geoid") or ConfigParams().s1backscatter_elev_geoid
        elev_default = sar_backscatter_arguments.options.get("elev_default")
        logger.info(f"elev_geoid: {elev_geoid!r}")

        noise_removal = bool(sar_backscatter_arguments.noise_removal)
        debug_mode = smart_bool(sar_backscatter_arguments.options.get("debug"))

        feature_pyrdd, layer_metadata_py = self._build_feature_rdd(
            collection_id=collection_id,
            projected_polygons=projected_polygons,
            from_date=from_date,
            to_date=to_date,
            extra_properties=extra_properties,
            tile_size=tile_size,
            zoom=zoom,
            correlation_id=correlation_id)
        if debug_mode:
            self._debug_show_rdd_info(feature_pyrdd)

        # Group multiple tiles by product id
        def process_feature(feature: dict) -> Tuple[str, dict]:
            """Map one feature record to ``(creo_path, tile info)`` for grouping per product."""
            creo_path = feature["feature"]["id"]
            return creo_path, {
                "key": feature["key"],
                "key_extent": feature["key_extent"],
                "bbox": feature["feature"]["bbox"],
                "key_epsg": feature["metadata"]["crs_epsg"]
            }

        per_product = feature_pyrdd.map(
            process_feature).groupByKey().mapValues(list)

        # TODO: still split if full layout extent is too large for processing as a whole?

        # Apply Orfeo processing over product files as whole and splice up in tiles after that
        @epsel.ensure_info_logging
        @TimingLogger(title="process_product", logger=logger)
        def process_product(product: Tuple[str, List[dict]]):
            """Run Orfeo on one whole product and cut the result into layout tiles.

            :param product: ``(creo_path, features)`` pair from the grouping step.
            :return: list of ``(SpaceTimeKey, Tile)`` pairs. Tiles whose values all equal
                the nodata value (as reported for the last processed band) are skipped.
            :raises OpenEOApiException: when the product path does not exist or a
                requested band has no tiff.
            """
            import faulthandler
            faulthandler.enable()
            creo_path, features = product

            # Short ad-hoc product id for logging purposes.
            prod_id = re.sub(r"[^A-Z0-9]", "", creo_path.upper())[-10:]
            log_prefix = f"p{os.getpid()}-prod{prod_id}"
            logger.info(f"{log_prefix} creo path {creo_path}")
            logger.info(
                f"{log_prefix} sar_backscatter_arguments: {sar_backscatter_arguments!r}"
            )

            creo_path = pathlib.Path(creo_path)
            if not creo_path.exists():
                raise OpenEOApiException("Creo path does not exist")

            # Get whole extent of tile layout
            col_min = min(f["key"]["col"] for f in features)
            col_max = max(f["key"]["col"] for f in features)
            cols = col_max - col_min + 1
            row_min = min(f["key"]["row"] for f in features)
            row_max = max(f["key"]["row"] for f in features)
            rows = row_max - row_min + 1
            instants = set(f["key"]["instant"] for f in features)
            assert len(instants) == 1, f"Not single instant: {instants}"
            instant = instants.pop()
            logger.info(
                f"{log_prefix} Layout key extent: col[{col_min}:{col_max}] row[{row_min}:{row_max}]"
                f" ({cols}x{rows}={cols * rows} tiles) instant[{instant}].")

            layout_extent = get_total_extent(features)

            key_epsgs = set(f["key_epsg"] for f in features)
            assert len(key_epsgs) == 1, f"Multiple key CRSs {key_epsgs}"
            layout_epsg = key_epsgs.pop()
            layout_width_px = tile_size * cols
            layout_height_px = tile_size * rows
            logger.info(
                f"{log_prefix} Layout extent {layout_extent} EPSG {layout_epsg}:"
                f" {layout_width_px}x{layout_height_px}px")

            band_tiffs = S1BackscatterOrfeo._creo_scan_for_band_tiffs(
                creo_path, log_prefix)

            dem_dir_context = S1BackscatterOrfeo._get_dem_dir_context(
                sar_backscatter_arguments=sar_backscatter_arguments,
                extent=layout_extent,
                epsg=layout_epsg)

            msg = f"{log_prefix} Process {creo_path} "
            with TimingLogger(title=msg,
                              logger=logger), dem_dir_context as dem_dir:
                # Allocate numpy array tile
                orfeo_bands = numpy.zeros(
                    (len(bands), layout_height_px, layout_width_px),
                    dtype=result_dtype)

                for b, band in enumerate(bands):
                    if band.lower() not in band_tiffs:
                        raise OpenEOApiException(f"No tiff for band {band}")
                    data, nodata = S1BackscatterOrfeoV2._orfeo_pipeline(
                        input_tiff=band_tiffs[band.lower()],
                        extent=layout_extent,
                        extent_epsg=layout_epsg,
                        dem_dir=dem_dir,
                        extent_width_px=layout_width_px,
                        extent_height_px=layout_height_px,
                        sar_calibration_lut=sar_calibration_lut,
                        noise_removal=noise_removal,
                        elev_geoid=elev_geoid,
                        elev_default=elev_default,
                        log_prefix=f"{log_prefix}-{band}",
                        orfeo_memory=orfeo_memory)
                    orfeo_bands[b] = data

                if sar_backscatter_arguments.options.get("to_db", False):
                    # TODO: keep this "to_db" shortcut feature or drop it
                    #       and require user to use standard openEO functionality (`apply` based conversion)?
                    logger.info(
                        f"{log_prefix} Converting backscatter intensity to decibel"
                    )
                    orfeo_bands = 10 * numpy.log10(orfeo_bands)

                # Split orfeo output in tiles
                logger.info(
                    f"{log_prefix} Split {orfeo_bands.shape} in tiles of {tile_size}"
                )
                cell_type = geopyspark.CellType(orfeo_bands.dtype.name)
                tiles = []
                for c in range(cols):
                    for r in range(rows):
                        col = col_min + c
                        row = row_min + r
                        key = geopyspark.SpaceTimeKey(
                            col=col,
                            row=row,
                            instant=_instant_ms_to_day(instant))
                        tile = orfeo_bands[:,
                                           r * tile_size:(r + 1) * tile_size,
                                           c * tile_size:(c + 1) * tile_size]
                        if not (tile == nodata).all():
                            logger.info(
                                f"{log_prefix} Create Tile for key {key} from {tile.shape}"
                            )
                            tile = geopyspark.Tile(tile,
                                                   cell_type,
                                                   no_data_value=nodata)
                            tiles.append((key, tile))

            logger.info(
                f"{log_prefix} Layout extent split in {len(tiles)} tiles")
            return tiles

        paths = list(per_product.keys().collect())
        # One partition per product. Map each path to its index once instead of running
        # a linear ``list.index`` scan for every record.
        path_partition = {path: index for index, path in enumerate(paths)}

        def partitionByPath(path):
            """Partition index for ``path``; unknown keys fall back to Spark's portable hash."""
            index = path_partition.get(path)
            if index is None:
                return pyspark.rdd.portable_hash(path)
            return index

        # Keys are unique after groupByKey, so len(paths) == per_product.count(); reusing it
        # avoids running another Spark job.
        grouped = per_product.partitionBy(len(paths), partitionByPath)
        tile_rdd = grouped.flatMap(process_product)
        if result_dtype:
            layer_metadata_py.cell_type = result_dtype
        logger.info(
            "Constructing TiledRasterLayer from numpy rdd, with metadata {m!r}"
            .format(m=layer_metadata_py))
        tile_layer = geopyspark.TiledRasterLayer.from_numpy_rdd(
            layer_type=geopyspark.LayerType.SPACETIME,
            numpy_rdd=tile_rdd,
            metadata=layer_metadata_py)
        return {zoom: tile_layer}
예제 #19
0
    def _get_process_function(sar_backscatter_arguments, result_dtype, bands):
        """Build a function that runs the Orfeo pipeline for each tile of one product.

        Every feature (tile key) of a product is processed separately, using its own
        ``key_extent`` at ``tile_size`` x ``tile_size`` pixels.

        :param sar_backscatter_arguments: backscatter settings. ``options`` may contain
            ``tile_size``, ``elev_geoid``, ``elev_default`` and ``to_db``.
        :param result_dtype: numpy dtype of the produced tiles.
        :param bands: band names to compute, in output order.
        :return: function mapping ``(creo_path, features)`` to a list of
            ``(SpaceTimeKey, Tile)`` pairs.
        """

        # Tile size to use in the TiledRasterLayer.
        tile_size = sar_backscatter_arguments.options.get(
            "tile_size", S1BackscatterOrfeo._DEFAULT_TILE_SIZE)
        noise_removal = bool(sar_backscatter_arguments.noise_removal)

        # Geoid for orthorectification: get from options, fallback on config.
        elev_geoid = sar_backscatter_arguments.options.get(
            "elev_geoid") or ConfigParams().s1backscatter_elev_geoid
        elev_default = sar_backscatter_arguments.options.get("elev_default")
        logger.info(f"elev_geoid: {elev_geoid!r}")

        sar_calibration_lut = S1BackscatterOrfeo._get_sar_calibration_lut(
            sar_backscatter_arguments.coefficient)

        @epsel.ensure_info_logging
        @TimingLogger(title="process_feature", logger=logger)
        def process_feature(product: Tuple[str, List[dict]]):
            """Process all features of one product, one Orfeo run per feature and band.

            :param product: ``(creo_path, features)`` pair.
            :return: list of ``(SpaceTimeKey, Tile)`` pairs, one per feature.
            :raises OpenEOApiException: when the product path does not exist or a
                requested band has no tiff.
            """
            import faulthandler
            faulthandler.enable()
            creo_path, features = product

            # Short ad-hoc product id for logging purposes.
            prod_id = re.sub(r"[^A-Z0-9]", "", creo_path.upper())[-10:]
            log_prefix = f"p{os.getpid()}-prod{prod_id}"
            # Go through the logger like every other message here, not print().
            logger.info(f"{log_prefix} creo path {creo_path}")
            logger.info(
                f"{log_prefix} sar_backscatter_arguments: {sar_backscatter_arguments!r}"
            )

            layout_extent = get_total_extent(features)
            key_epsgs = set(f["key_epsg"] for f in features)
            assert len(key_epsgs) == 1, f"Multiple key CRSs {key_epsgs}"
            layout_epsg = key_epsgs.pop()

            dem_dir_context = S1BackscatterOrfeo._get_dem_dir_context(
                sar_backscatter_arguments=sar_backscatter_arguments,
                extent=layout_extent,
                epsg=layout_epsg)

            creo_path = pathlib.Path(creo_path)
            # The path is the same for every feature. Check it once, before scanning it for
            # band tiffs and before entering the DEM directory context.
            if not creo_path.exists():
                raise OpenEOApiException("Creo path does not exist")

            band_tiffs = S1BackscatterOrfeo._creo_scan_for_band_tiffs(
                creo_path, log_prefix)

            resultlist = []

            with dem_dir_context as dem_dir:

                for feature in features:
                    col, row, instant = (feature["key"][k]
                                         for k in ["col", "row", "instant"])

                    key_ext = feature["key_extent"]
                    key_epsg = layout_epsg

                    logger.info(
                        f"{log_prefix} Feature creo path: {creo_path}, key {key_ext} (EPSG {key_epsg})"
                    )
                    logger.info(
                        f"{log_prefix} sar_backscatter_arguments: {sar_backscatter_arguments!r}"
                    )

                    msg = f"{log_prefix} Process {creo_path} and load into geopyspark Tile"
                    with TimingLogger(title=msg, logger=logger):
                        # Allocate numpy array tile
                        tile_data = numpy.zeros(
                            (len(bands), tile_size, tile_size),
                            dtype=result_dtype)

                        for b, band in enumerate(bands):
                            if band.lower() not in band_tiffs:
                                raise OpenEOApiException(
                                    f"No tiff for band {band}")
                            data, nodata = S1BackscatterOrfeo._orfeo_pipeline(
                                input_tiff=band_tiffs[band.lower()],
                                extent=key_ext,
                                extent_epsg=key_epsg,
                                dem_dir=dem_dir,
                                extent_width_px=tile_size,
                                extent_height_px=tile_size,
                                sar_calibration_lut=sar_calibration_lut,
                                noise_removal=noise_removal,
                                elev_geoid=elev_geoid,
                                elev_default=elev_default,
                                log_prefix=f"{log_prefix}-{band}")
                            tile_data[b] = data

                        if sar_backscatter_arguments.options.get(
                                "to_db", False):
                            # TODO: keep this "to_db" shortcut feature or drop it
                            #       and require user to use standard openEO functionality (`apply` based conversion)?
                            logger.info(
                                f"{log_prefix} Converting backscatter intensity to decibel"
                            )
                            tile_data = 10 * numpy.log10(tile_data)

                        key = geopyspark.SpaceTimeKey(
                            row=row,
                            col=col,
                            instant=_instant_ms_to_day(instant))
                        cell_type = geopyspark.CellType(tile_data.dtype.name)
                        logger.info(
                            f"{log_prefix} Create Tile for key {key} from {tile_data.shape}"
                        )
                        tile = geopyspark.Tile(tile_data,
                                               cell_type,
                                               no_data_value=nodata)
                        resultlist.append((key, tile))

            return resultlist

        return process_feature
예제 #20
0
 def __init__(
     self,
     root_path: str = None,
     zookeeper_hosts: str = None):
     """Store the ZooKeeper root path and create a Kazoo client (not started here).

     :param root_path: root znode path; defaults to
         ``ConfigParams().batch_jobs_zookeeper_root_path``.
     :param zookeeper_hosts: comma-separated ZooKeeper hosts; defaults to the configured
         ``ConfigParams().zookeepernodes``.

     Defaults are resolved per instance. Previously the signature defaults read the
     configuration once, at import time.
     """
     if root_path is None:
         root_path = ConfigParams().batch_jobs_zookeeper_root_path
     if zookeeper_hosts is None:
         zookeeper_hosts = ','.join(ConfigParams().zookeepernodes)
     self._root = root_path
     self._zk = KazooClient(hosts=zookeeper_hosts)
 def __init__(self,
              zookeeper_hosts: str = None):
     """Use ``/openeo/jobs`` as root znode and create a Kazoo client (not started here).

     :param zookeeper_hosts: comma-separated ZooKeeper hosts; defaults to the configured
         ``ConfigParams().zookeepernodes``, resolved per instance rather than at import time.
     """
     if zookeeper_hosts is None:
         zookeeper_hosts = ','.join(ConfigParams().zookeepernodes)
     self._root = '/openeo/jobs'
     self._zk = KazooClient(hosts=zookeeper_hosts)
예제 #22
0
def run_job(job_specification,
            output_file: Path,
            metadata_file: Path,
            api_version,
            job_dir,
            dependencies: dict,
            user_id: str = None):
    """Evaluate a batch job's process graph and write its results and metadata.

    Steps:
      1. evaluate ``job_specification['process_graph']`` with a dry-run tracer;
      2. wrap the result in a ``SaveResult`` (GTiff for data cubes, JSON for anything
         that is not already a ``SaveResult``) and write it to ``output_file``;
      3. if any dependency is flagged as CARD4L, download its STAC data from S3 into
         ``job_dir``;
      4. export the result metadata to ``metadata_file``;
      5. on Kubernetes deployments, upload the top-level entries of ``job_dir`` to
         object storage.

    :param job_specification: parsed job spec; must contain ``process_graph`` and may
        contain ``output`` (format options), ``title`` and ``description``.
    :param output_file: path the result is written to.
    :param metadata_file: path the job metadata is written to.
    :param api_version: openEO API version; ``"1.0.0"`` is used when falsy.
    :param job_dir: job working directory.
    :param dependencies: maps ``(collection_id, metadata_properties)`` to
        ``(source_location, card4l)`` tuples.
    :param user_id: id of the user owning the job.
    :raises NotImplementedError: when the result type cannot be written.
    """
    logger.info(f"Job spec: {json.dumps(job_specification,indent=1)}")
    process_graph = job_specification['process_graph']

    backend_implementation = GeoPySparkBackendImplementation()
    logger.info(f"Using backend implementation {backend_implementation}")
    correlation_id = str(uuid.uuid4())
    logger.info(f"Correlation id: {correlation_id}")
    env = EvalEnv({
        'version': api_version or "1.0.0",
        'pyramid_levels': 'highest',
        'user': User(user_id=user_id),
        'require_bounds': True,
        'correlation_id': correlation_id,
        'dependencies': dependencies,
        "backend_implementation": backend_implementation,
    })
    tracer = DryRunDataTracer()
    logger.info("Starting process graph evaluation")
    result = ProcessGraphDeserializer.evaluate(process_graph,
                                               env=env,
                                               do_dry_run=tracer)
    logger.info("Evaluated process graph, result (type {t}): {r!r}".format(
        t=type(result), r=result))

    if isinstance(result, DelayedVector):
        geojsons = (mapping(geometry) for geometry in result.geometries)
        result = JSONResult(geojsons)

    if isinstance(result, DriverDataCube):
        # Copy so the caller's job specification is not mutated by the "batch_mode" flag;
        # `or {}` also covers an explicit `"output": null`.
        format_options = dict(job_specification.get('output') or {})
        format_options["batch_mode"] = True
        result = ImageCollectionResult(cube=result,
                                       format='GTiff',
                                       options=format_options)

    if not isinstance(result, SaveResult):  # Assume generic JSON result
        result = JSONResult(result)

    global_metadata_attributes = {
        "title": job_specification.get("title", ""),
        "description": job_specification.get("description", ""),
        "institution": "openEO platform - Geotrellis backend: " + __version__
    }

    assets_metadata = None
    if hasattr(result, 'write_assets'):
        result.options["batch_mode"] = True
        result.options["file_metadata"] = global_metadata_attributes
        if result.options.get("sample_by_feature"):
            geoms = tracer.get_geometries("filter_spatial")
            if len(geoms) > 1:
                logger.warning(
                    "Multiple filter_spatial geometries: {c}".format(
                        c=len(geoms)))
            elif len(geoms) == 0:
                logger.warning(
                    "sample_by_feature enabled, but no geometries found. They can be specified using filter_spatial."
                )
            else:
                result.options["geometries"] = geoms[0]
            # `.get`: the "geometries" option is not necessarily set at this point (no or
            # multiple filter_spatial geometries), which used to raise a KeyError here.
            if result.options.get("geometries") is None:
                logger.error(
                    "sample_by_feature was set, but no geometries provided through filter_spatial. Make sure to provide geometries."
                )
        assets_metadata = result.write_assets(str(output_file))
        for asset in assets_metadata.values():
            _add_permissions(Path(asset["href"]), stat.S_IWGRP)
        logger.info("wrote image collection to %s" % output_file)

    elif isinstance(result, ImageCollectionResult):
        result.options["batch_mode"] = True
        result.save_result(filename=str(output_file))
        _add_permissions(output_file, stat.S_IWGRP)
        logger.info("wrote image collection to %s" % output_file)
    elif isinstance(result, MultipleFilesResult):
        result.reduce(output_file, delete_originals=True)
        _add_permissions(output_file, stat.S_IWGRP)
        logger.info("reduced %d files to %s" %
                    (len(result.files), output_file))
    elif isinstance(result, NullResult):
        logger.info("skipping output file %s" % output_file)
    else:
        raise NotImplementedError(
            "unsupported result type {r}".format(r=type(result)))

    if any(card4l
           for _, card4l in dependencies.values()):  # TODO: clean this up
        logger.debug("awaiting Sentinel Hub CARD4L data...")

        s3_service = get_jvm().org.openeo.geotrellissentinelhub.S3Service()

        poll_interval_secs = 10
        max_delay_secs = 600

        card4l_dependencies = [
            (collection_id, source_location)
            for (collection_id, _metadata_properties), (source_location, card4l)
            in dependencies.items()
            if card4l
        ]

        for collection_id, source_location in card4l_dependencies:
            # source_location is an s3://<bucket>/<request group id> URI.
            uri_parts = urlparse(source_location)
            bucket_name = uri_parts.hostname
            request_group_id = uri_parts.path[1:]

            try:
                # FIXME: incorporate collection_id and metadata_properties to make sure the files don't clash
                s3_service.download_stac_data(bucket_name, request_group_id,
                                              str(job_dir), poll_interval_secs,
                                              max_delay_secs)
                logger.info("downloaded CARD4L data in {b}/{g} to {d}".format(
                    b=bucket_name, g=request_group_id, d=job_dir))
            except Py4JJavaError as e:
                java_exception = e.java_exception

                # Missing metadata after the polling window is tolerated; anything else is fatal.
                if (java_exception.getClass().getName(
                ) == 'org.openeo.geotrellissentinelhub.S3Service$StacMetadataUnavailableException'
                    ):
                    logger.warning(
                        "could not find CARD4L metadata to download from s3://{b}/{r} after {d}s"
                        .format(b=bucket_name,
                                r=request_group_id,
                                d=max_delay_secs))
                else:
                    raise

        _transform_stac_metadata(job_dir)

    unique_process_ids = CollectUniqueProcessIdsVisitor().accept_process_graph(
        process_graph).process_ids

    _export_result_metadata(tracer=tracer,
                            result=result,
                            output_file=output_file,
                            metadata_file=metadata_file,
                            unique_process_ids=unique_process_ids,
                            asset_metadata=assets_metadata)

    if ConfigParams().is_kube_deploy:
        from openeogeotrellis.utils import s3_client

        bucket = os.environ.get('SWIFT_BUCKET')
        s3_instance = s3_client()

        logger.info("Writing results to object storage")
        # Only the top-level entries of job_dir are uploaded; the object key is the
        # local path without leading/trailing slashes.
        for file in os.listdir(job_dir):
            full_path = str(job_dir) + "/" + file
            s3_instance.upload_file(full_path, bucket, full_path.strip("/"))
예제 #23
0
def get_layer_catalog(opensearch_enrich=False) -> GeoPySparkLayerCatalog:
    """
    Get layer catalog (from JSON files)

    Collection metadata is read from every file in
    ``ConfigParams().layer_catalog_metadata_files``, merged with ``overwrite=True``
    (presumably later files win). With ``opensearch_enrich``, the metadata of collections
    backed by an OpenSearch catalog or by Sentinel Hub is additionally fetched from those
    services and merged underneath the locally configured metadata.

    :param opensearch_enrich: whether to fetch and merge remote collection metadata.
    :return: a layer catalog over the merged collection metadata.
    """
    metadata: Dict[str, dict] = {}

    def read_catalog_file(catalog_file) -> Dict[str, dict]:
        return {coll["id"]: coll for coll in read_json(catalog_file)}

    catalog_files = ConfigParams().layer_catalog_metadata_files
    for path in catalog_files:
        logger.info(f"Reading layer catalog metadata from {path}")
        metadata = dict_merge_recursive(metadata,
                                        read_catalog_file(path),
                                        overwrite=True)

    if opensearch_enrich:
        opensearch_metadata = {}
        sh_collection_metadatas = None
        opensearch_instances = {}
        sh_stac_endpoint = "https://collections.eurodatacube.com/stac/index.json"

        def opensearch_instance(endpoint: str) -> OpenSearch:
            """Return a cached OpenSearch client for ``endpoint`` (case-insensitive)."""
            endpoint = endpoint.lower()
            # Look up with the normalized `endpoint`, i.e. the same key used for storing
            # below; this previously read the enclosing loop's `os_endpoint`, so
            # mixed-case endpoints never hit the cache.
            opensearch = opensearch_instances.get(endpoint)

            if opensearch is not None:
                return opensearch

            if "oscars" in endpoint or "terrascope" in endpoint or "vito.be" in endpoint:
                opensearch = OpenSearchOscars(endpoint=endpoint)
            elif "creodias" in endpoint:
                opensearch = OpenSearchCreodias(endpoint=endpoint)
            else:
                raise ValueError(endpoint)

            opensearch_instances[endpoint] = opensearch
            return opensearch

        for cid, collection_metadata in metadata.items():
            data_source = deep_get(collection_metadata,
                                   "_vito",
                                   "data_source",
                                   default={})
            os_cid = data_source.get("opensearch_collection_id")
            if os_cid:
                os_endpoint = data_source.get(
                    "opensearch_endpoint") or ConfigParams(
                    ).default_opensearch_endpoint
                logger.info(
                    f"Updating {cid} metadata from {os_endpoint}:{os_cid}")
                try:
                    opensearch_metadata[cid] = opensearch_instance(
                        os_endpoint).get_metadata(collection_id=os_cid)
                except Exception:
                    # Best effort: a broken catalog must not prevent the others from loading.
                    logger.warning(traceback.format_exc())
            elif data_source.get("type") == "sentinel-hub":
                sh_cid = data_source.get("collection_id")

                if sh_cid is None:
                    continue

                # Fetch the Sentinel Hub STAC collections once, on first need.
                if sh_collection_metadatas is None:
                    sh_collections = requests.get(sh_stac_endpoint).json()
                    sh_collection_metadatas = [
                        requests.get(c["link"]).json()
                        for c in sh_collections
                    ]

                sh_metadata = next(
                    (m for m in sh_collection_metadatas
                     if m.get("datasource_type") == sh_cid),
                    None)
                if sh_metadata is None:
                    logger.warning(
                        f"No STAC data available for collection with id {sh_cid}"
                    )
                    continue

                logger.info(
                    f"Updating {cid} metadata from {sh_stac_endpoint}:{sh_metadata['id']}"
                )
                opensearch_metadata[cid] = sh_metadata
                # `data_source` is the dict inside `metadata`, so these fill in the
                # local configuration where it leaves values unset.
                if not data_source.get("endpoint"):
                    endpoint = sh_metadata["providers"][0]["url"]
                    endpoint = endpoint if endpoint.startswith(
                        "http") else "https://{}".format(endpoint)
                    data_source["endpoint"] = endpoint
                data_source["dataset_id"] = data_source.get(
                    "dataset_id") or sh_metadata["datasource_type"]

        if opensearch_metadata:
            metadata = dict_merge_recursive(opensearch_metadata,
                                            metadata,
                                            overwrite=True)

    metadata = _merge_layers_with_common_name(metadata)

    return GeoPySparkLayerCatalog(all_metadata=list(metadata.values()))
예제 #24
0
    def load_collection(
            self, collection_id: str,
            viewing_parameters: dict) -> 'GeotrellisTimeSeriesImageCollection':
        """Load a collection as a time series image collection backed by a Geotrellis pyramid.

        The collection's ``_vito.data_source.type`` (default ``"Accumulo"``, compared
        lower-cased) selects the JVM pyramid factory; unknown types fall back to Accumulo.

        :param collection_id: id of the collection in the layer catalog.
        :param viewing_parameters: optional ``from``/``to`` dates, ``left``/``right``/
            ``top``/``bottom`` bounds expressed in ``srs``, and ``bands`` (band names).
        :return: the image collection, band-filtered afterwards when the source factory
            could not select bands itself.
        :raises ProcessGraphComplexityException: when not all four spatial bounds are
            given and ``ConfigParams().require_bounds`` is set.
        """
        logger.info("Creating layer for {c} with viewingParameters {v}".format(
            c=collection_id, v=viewing_parameters))

        # TODO is it necessary to do this kerberos stuff here?
        kerberos()

        metadata = CollectionMetadata(
            self.get_collection_metadata(collection_id, strip_private=False))
        layer_source_info = metadata.get("_vito", "data_source", default={})
        layer_source_type = layer_source_info.get("type", "Accumulo").lower()
        logger.info("Layer source type: {s!r}".format(s=layer_source_type))

        import geopyspark as gps
        from_date = normalize_date(viewing_parameters.get("from", None))
        to_date = normalize_date(viewing_parameters.get("to", None))

        left = viewing_parameters.get("left", None)
        right = viewing_parameters.get("right", None)
        top = viewing_parameters.get("top", None)
        bottom = viewing_parameters.get("bottom", None)
        srs = viewing_parameters.get("srs", None)
        bands = viewing_parameters.get("bands", None)
        band_indices = [metadata.get_band_index(b)
                        for b in bands] if bands else None
        logger.info("band_indices: {b!r}".format(b=band_indices))
        # TODO: avoid this `still_needs_band_filter` ugliness.
        #       Also see https://github.com/Open-EO/openeo-geopyspark-driver/issues/29
        still_needs_band_filter = False
        pysc = gps.get_spark_context()

        gateway = JavaGateway(
            eager_load=True,
            gateway_parameters=pysc._gateway.gateway_parameters)
        jvm = gateway.jvm

        spatial_bounds_present = all(
            bound is not None for bound in (left, right, top, bottom))

        if spatial_bounds_present:
            extent = jvm.geotrellis.vector.Extent(float(left), float(bottom),
                                                  float(right), float(top))
        elif ConfigParams().require_bounds:
            raise ProcessGraphComplexityException
        else:
            # No (complete) bounds: use the whole world in lat/lon.
            srs = "EPSG:4326"
            extent = jvm.geotrellis.vector.Extent(-180.0, -90.0, 180.0, 90.0)

        def accumulo_pyramid():
            pyramidFactory = jvm.org.openeo.geotrellisaccumulo.PyramidFactory(
                "hdp-accumulo-instance",
                ','.join(ConfigParams().zookeepernodes))
            if layer_source_info.get("split", False):
                pyramidFactory.setSplitRanges(True)

            accumulo_layer_name = layer_source_info['data_id']
            # This factory cannot select bands, so filter them afterwards.
            nonlocal still_needs_band_filter
            still_needs_band_filter = bool(band_indices)
            return pyramidFactory.pyramid_seq(accumulo_layer_name, extent, srs,
                                              from_date, to_date)

        def s3_pyramid():
            endpoint = layer_source_info['endpoint']
            region = layer_source_info['region']
            bucket_name = layer_source_info['bucket_name']
            # This factory cannot select bands, so filter them afterwards.
            nonlocal still_needs_band_filter
            still_needs_band_filter = bool(band_indices)
            return jvm.org.openeo.geotrelliss3.PyramidFactory(endpoint, region, bucket_name) \
                .pyramid_seq(extent, srs, from_date, to_date)

        def s3_jp2_pyramid():
            endpoint = layer_source_info['endpoint']
            region = layer_source_info['region']

            return jvm.org.openeo.geotrelliss3.Jp2PyramidFactory(endpoint, region) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def file_s2_radiometry_pyramid():
            return jvm.org.openeo.geotrellis.file.Sentinel2RadiometryPyramidFactory() \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def file_s2_pyramid():
            oscars_collection_id = layer_source_info['oscars_collection_id']
            oscars_link_titles = metadata.band_names
            root_path = layer_source_info['root_path']

            filtered_link_titles = [
                oscars_link_titles[i] for i in band_indices
            ] if band_indices else oscars_link_titles

            return jvm.org.openeo.geotrellis.file.Sentinel2PyramidFactory(
                oscars_collection_id, filtered_link_titles,
                root_path).pyramid_seq(extent, srs, from_date, to_date)

        def geotiff_pyramid():
            # Build (and cache) the factory only the first time this collection is loaded;
            # `setdefault(collection_id, from_disk(...))` used to call `from_disk` on
            # every load even when a cached factory already existed.
            factory = self._geotiff_pyramid_factories.get(collection_id)
            if factory is None:
                glob_pattern = layer_source_info['glob_pattern']
                date_regex = layer_source_info['date_regex']
                factory = jvm.org.openeo.geotrellis.geotiff.PyramidFactory.from_disk(
                    glob_pattern, date_regex)
                self._geotiff_pyramid_factories[collection_id] = factory

            return factory.pyramid_seq(extent, srs, from_date, to_date)

        def file_s1_coherence_pyramid():
            return jvm.org.openeo.geotrellis.file.Sentinel1CoherencePyramidFactory() \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def sentinel_hub_s1_pyramid():
            return jvm.org.openeo.geotrellissentinelhub.S1PyramidFactory(layer_source_info.get('uuid')) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def sentinel_hub_s2_l1c_pyramid():
            return jvm.org.openeo.geotrellissentinelhub.S2L1CPyramidFactory(layer_source_info.get('uuid')) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def sentinel_hub_s2_l2a_pyramid():
            return jvm.org.openeo.geotrellissentinelhub.S2L2APyramidFactory(layer_source_info.get('uuid')) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def sentinel_hub_l8_pyramid():
            return jvm.org.openeo.geotrellissentinelhub.L8PyramidFactory(layer_source_info.get('uuid')) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        pyramid_loaders = {
            's3': s3_pyramid,
            's3-jp2': s3_jp2_pyramid,
            'file-s2-radiometry': file_s2_radiometry_pyramid,
            'file-s2': file_s2_pyramid,
            'geotiff': geotiff_pyramid,
            'file-s1-coherence': file_s1_coherence_pyramid,
            'sentinel-hub-s1': sentinel_hub_s1_pyramid,
            'sentinel-hub-s2-l1c': sentinel_hub_s2_l1c_pyramid,
            'sentinel-hub-s2-l2a': sentinel_hub_s2_l2a_pyramid,
            'sentinel-hub-l8': sentinel_hub_l8_pyramid,
        }

        logger.info("loading pyramid {s}".format(s=layer_source_type))
        # Any other source type (including the default "accumulo") is loaded from Accumulo.
        pyramid = pyramid_loaders.get(layer_source_type, accumulo_pyramid)()

        temporal_tiled_raster_layer = jvm.geopyspark.geotrellis.TemporalTiledRasterLayer
        option = jvm.scala.Option
        levels = {}
        for index in range(pyramid.size()):
            # Each entry is a Scala tuple whose `_1()` is the level key (zoom) and `_2()`
            # the layer; fetch it once instead of three Py4J round trips per level.
            level = pyramid.apply(index)
            zoom = level._1()
            levels[zoom] = TiledRasterLayer(
                LayerType.SPACETIME,
                temporal_tiled_raster_layer(option.apply(zoom), level._2()))

        image_collection = GeotrellisTimeSeriesImageCollection(
            pyramid=gps.Pyramid(levels),
            service_registry=self._service_registry,
            metadata=metadata)

        if still_needs_band_filter:
            # TODO: avoid this `still_needs_band_filter` ugliness.
            #       Also see https://github.com/Open-EO/openeo-geopyspark-driver/issues/29
            image_collection = image_collection.band_filter(band_indices)

        return image_collection
예제 #25
0
def main():
    """Run the single asynchronous task described by the ``--task`` JSON argument.

    Supported task ids:
      * deleting batch process results / dependency sources of a batch job;
      * polling Sentinel Hub batch processes until the job's ``dependency_status`` is no
        longer ``awaiting``/``awaiting_retry`` (on failure the job is marked ``error``).

    Any exception is logged and re-raised.
    """
    import argparse

    logging.basicConfig(level=logging.INFO)
    openeogeotrellis.backend.logger.setLevel(logging.DEBUG)

    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setFormatter(JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s",
                                       datefmt="%Y-%m-%dT%H:%M:%S%z"))

    root_logger = logging.getLogger()
    root_logger.addHandler(handler)

    _log.info("argv: {a!r}".format(a=sys.argv))
    _log.info("ConfigParams(): {c}".format(c=ConfigParams()))

    # FIXME: there's no Java output because Py4J redirects the JVM's stdout/stderr to /dev/null unless JavaGateway's
    #  redirect_stdout/redirect_stderr are set (EP-4018)

    try:
        parser = argparse.ArgumentParser(usage="OpenEO AsyncTask --task <task>",
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument("--py4j-jarpath", default="venv/share/py4j/py4j0.10.7.jar", help='Path to the Py4J jar')
        parser.add_argument("--py4j-classpath", default="geotrellis-extensions-2.2.0-SNAPSHOT.jar",
                            help='Classpath used to launch the Java Gateway')
        parser.add_argument("--principal", default="*****@*****.**", help="Principal to be used to login to KDC")
        parser.add_argument("--keytab", default="openeo-deploy/mep/openeo.keytab",
                            help="The full path to the file that contains the keytab for the principal")
        parser.add_argument("--task", required=True, dest="task_json", help="The task description in JSON")

        args = parser.parse_args()

        task = json.loads(args.task_json)
        task_id = task['task_id']
        if task_id not in [TASK_DELETE_BATCH_PROCESS_RESULTS, TASK_POLL_SENTINELHUB_BATCH_PROCESSES,
                           TASK_DELETE_BATCH_PROCESS_DEPENDENCY_SOURCES]:
            raise ValueError(f'unsupported task_id "{task_id}"')

        arguments: dict = task.get('arguments', {})

        gps_batch_jobs = None

        def batch_jobs() -> GpsBatchJobs:
            """Return this task's GpsBatchJobs, creating it on first use.

            Creating one launches a JVM gateway and loads an OpenSearch-enriched layer
            catalog, so it is created at most once instead of once per polling iteration.
            """
            nonlocal gps_batch_jobs
            if gps_batch_jobs is None:
                java_opts = [
                    "-client",
                    "-Dsoftware.amazon.awssdk.http.service.impl=software.amazon.awssdk.http.urlconnection.UrlConnectionSdkHttpService"
                ]

                java_gateway = JavaGateway.launch_gateway(jarpath=args.py4j_jarpath,
                                                          classpath=args.py4j_classpath,
                                                          javaopts=java_opts,
                                                          die_on_exit=True)

                gps_batch_jobs = GpsBatchJobs(get_layer_catalog(opensearch_enrich=True), java_gateway.jvm,
                                              args.principal, args.keytab)
            return gps_batch_jobs

        if task_id in [TASK_DELETE_BATCH_PROCESS_RESULTS, TASK_DELETE_BATCH_PROCESS_DEPENDENCY_SOURCES]:
            batch_job_id = arguments['batch_job_id']
            # Explicit dependency sources win; otherwise derive them from the subfolders.
            dependency_sources = (arguments.get('dependency_sources') or [f"s3://{sentinel_hub.OG_BATCH_RESULTS_BUCKET}/{subfolder}"
                                                                          for subfolder in arguments['subfolders']])

            _log.info(f"removing dependency sources {dependency_sources} for batch job {batch_job_id}...",
                      extra={'job_id': batch_job_id})
            batch_jobs().delete_batch_process_dependency_sources(job_id=batch_job_id,
                                                                 dependency_sources=dependency_sources,
                                                                 propagate_errors=True)
        elif task_id == TASK_POLL_SENTINELHUB_BATCH_PROCESSES:
            batch_job_id = arguments['batch_job_id']
            user_id = arguments['user_id']

            while True:
                time.sleep(SENTINEL_HUB_BATCH_PROCESSES_POLL_INTERVAL_S)

                with JobRegistry() as registry:
                    job_info = registry.get_job(batch_job_id, user_id)

                if job_info.get('dependency_status') not in ['awaiting', "awaiting_retry"]:
                    break
                else:
                    try:
                        batch_jobs().poll_sentinelhub_batch_processes(job_info)
                    except Exception:
                        # TODO: retry in Nifi? How to mark this job as 'error' then?
                        _log.error("failed to handle polling batch processes for batch job {j}:\n{e}"
                                   .format(j=batch_job_id, e=traceback.format_exc()),
                                   extra={'job_id': batch_job_id})

                        with JobRegistry() as registry:
                            registry.set_status(batch_job_id, user_id, 'error')
                            registry.mark_done(batch_job_id, user_id)

                        raise

        else:
            raise AssertionError(f'unexpected task_id "{task_id}"')
    except Exception as e:
        _log.error(e, exc_info=True)
        raise
예제 #26
0
def main(argv: List[str]) -> None:
    """Batch job entry point: parse ``argv``, set up the job directory and run the job in Spark.

    Positional arguments (after the program name): job specification file, job directory,
    results output file name, user log file name, metadata file name, and optionally the
    API version, serialized dependencies and user id. The file names are resolved
    relative to the job directory.

    Exits with status 1 when fewer than five positional arguments are given; any error
    while processing is logged (also to the user-facing log) and re-raised.
    """
    logger.info("argv: {a!r}".format(a=argv))
    logger.info("pid {p}; ppid {pp}; cwd {c}".format(p=os.getpid(),
                                                     pp=os.getppid(),
                                                     c=os.getcwd()))

    if len(argv) < 6:
        print(
            "usage: %s "
            "<job specification input file> <job directory> <results output file name> <user log file name> "
            "<metadata file name> [api version] [dependencies] [user id]" % argv[0],
            file=sys.stderr)
        # `sys.exit` rather than the `site`-provided `exit`, which is absent under `python -S`.
        sys.exit(1)

    job_specification_file = argv[1]
    job_dir = Path(argv[2])
    output_file = job_dir / argv[3]
    log_file = job_dir / argv[4]
    metadata_file = job_dir / argv[5]
    api_version = argv[6] if len(argv) >= 7 else None
    dependencies = _deserialize_dependencies(argv[7]) if len(argv) >= 8 else {}
    user_id = argv[8] if len(argv) >= 9 else None

    _create_job_dir(job_dir)

    _setup_user_logging(log_file)

    # Override default temp dir (under CWD). Original default temp dir `/tmp` might be cleaned up unexpectedly.
    temp_dir = Path(os.getcwd()) / "tmp"
    temp_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Using temp dir {t}".format(t=temp_dir))
    os.environ["TMPDIR"] = str(temp_dir)

    try:
        if ConfigParams().is_kube_deploy:
            from openeogeotrellis.utils import s3_client

            bucket = os.environ.get('SWIFT_BUCKET')
            s3_instance = s3_client()

            # The job specification's local path doubles as its object key.
            s3_instance.download_file(bucket,
                                      job_specification_file.strip("/"),
                                      job_specification_file)

        job_specification = _parse(job_specification_file)
        load_custom_processes()

        conf = (SparkConf().set(
            "spark.serializer", "org.apache.spark.serializer.KryoSerializer"
        ).set(
            key='spark.kryo.registrator',
            value='geopyspark.geotools.kryo.ExpandedKryoRegistrator'
        ).set(
            "spark.kryo.classesToRegister",
            "org.openeo.geotrellisaccumulo.SerializableConfiguration,ar.com.hjg.pngj.ImageInfo,ar.com.hjg.pngj.ImageLineInt,geotrellis.raster.RasterRegion$GridBoundsRasterRegion"
        ))

        with SparkContext(conf=conf) as sc:
            principal = sc.getConf().get("spark.yarn.principal")
            key_tab = sc.getConf().get("spark.yarn.keytab")

            kerberos(principal, key_tab)

            def run_driver():
                run_job(job_specification=job_specification,
                        output_file=output_file,
                        metadata_file=metadata_file,
                        api_version=api_version,
                        job_dir=job_dir,
                        dependencies=dependencies,
                        user_id=user_id)

            if sc.getConf().get('spark.python.profile',
                                'false').lower() == 'true':
                # Including the driver in the profiling: a bit hacky solution but spark profiler api does not allow passing args&kwargs
                driver_profile = BasicProfiler(sc)
                driver_profile.profile(run_driver)
                # running the driver code and adding driver's profiling results as "RDD==-1"
                sc.profiler_collector.add_profiler(-1, driver_profile)
                # collect profiles into a zip file
                profile_dumps_dir = job_dir / 'profile_dumps'
                sc.dump_profiles(profile_dumps_dir)

                profile_zip = shutil.make_archive(
                    base_name=str(profile_dumps_dir),
                    format='gztar',
                    root_dir=profile_dumps_dir)
                _add_permissions(Path(profile_zip), stat.S_IWGRP)

                # Best effort cleanup: a failed delete is only logged.
                shutil.rmtree(
                    profile_dumps_dir,
                    onerror=lambda func, path, exc_info: logger.warning(
                        f"could not recursively delete {profile_dumps_dir}: {func} {path} failed",
                        exc_info=exc_info))

                logger.info("Saved profiling info to: " + profile_zip)
            else:
                run_driver()

    except Exception as e:
        logger.exception("error processing batch job")
        user_facing_logger.exception("error processing batch job")
        if "Container killed on request. Exit code is 143" in str(e):
            user_facing_logger.error(
                "Your batch job failed because workers used too much Python memory. The same task was attempted multiple times. Consider increasing executor-memoryOverhead or contact the developers to investigate."
            )
        raise