Example 1
def evaluate() -> Dict:
    # Read the Kerberos principal and keytab that were passed to Spark on YARN.
    principal = sc.getConf().get("spark.yarn.principal")
    key_tab = sc.getConf().get("spark.yarn.keytab")

    # Authenticate before evaluating the process graph against secured services.
    kerberos(principal, key_tab)
    return ProcessGraphDeserializer.evaluate(process_graph, env=env)
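The kerberos() helper itself is not shown in these examples. As a rough sketch of what such a helper typically does, the function below obtains a ticket from a keytab with kinit when explicit credentials are available; its name, body and fallback behaviour are assumptions for illustration, not the actual openeogeotrellis implementation.

import subprocess
from typing import Optional


def kerberos_sketch(principal: Optional[str] = None, key_tab: Optional[str] = None) -> None:
    """Hypothetical stand-in for the kerberos() helper called in these examples."""
    if principal and key_tab:
        # Acquire a ticket-granting ticket from the keytab (assumed approach).
        subprocess.check_call(["kinit", "-kt", key_tab, principal])
    # Otherwise rely on an existing ticket cache (assumption).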
Example 2
    def load_disk_data(self, format: str, glob_pattern: str, options: dict,
                       viewing_parameters: dict) -> object:
        if format != 'GTiff':
            raise NotImplementedError(
                "The format is not supported by the backend: " + format)

        date_regex = options['date_regex']

        # Reading from HDFS requires a valid Kerberos ticket.
        if glob_pattern.startswith("hdfs:"):
            kerberos()

        from_date = normalize_date(viewing_parameters.get("from", None))
        to_date = normalize_date(viewing_parameters.get("to", None))

        left = viewing_parameters.get("left", None)
        right = viewing_parameters.get("right", None)
        top = viewing_parameters.get("top", None)
        bottom = viewing_parameters.get("bottom", None)
        srs = viewing_parameters.get("srs", None)
        band_indices = viewing_parameters.get("bands")

        sc = gps.get_spark_context()

        # Reach into the JVM through the Spark py4j gateway to call the
        # GeoTrellis PyramidFactory directly.
        gateway = JavaGateway(
            eager_load=True, gateway_parameters=sc._gateway.gateway_parameters)
        jvm = gateway.jvm

        extent = jvm.geotrellis.vector.Extent(float(left), float(bottom), float(right), float(top)) \
            if left is not None and right is not None and top is not None and bottom is not None else None

        pyramid = jvm.org.openeo.geotrellis.geotiff.PyramidFactory.from_disk(glob_pattern, date_regex) \
            .pyramid_seq(extent, srs, from_date, to_date)

        # Wrap each (zoom level, JVM layer) pair of the pyramid in a geopyspark
        # TiledRasterLayer, keyed by zoom level.
        temporal_tiled_raster_layer = jvm.geopyspark.geotrellis.TemporalTiledRasterLayer
        option = jvm.scala.Option
        levels = {
            pyramid.apply(index)._1(): TiledRasterLayer(
                LayerType.SPACETIME,
                temporal_tiled_raster_layer(
                    option.apply(pyramid.apply(index)._1()),
                    pyramid.apply(index)._2()))
            for index in range(0, pyramid.size())
        }

        image_collection = GeotrellisTimeSeriesImageCollection(
            pyramid=gps.Pyramid(levels),
            service_registry=self._service_registry,
            metadata={})

        return image_collection.band_filter(
            band_indices) if band_indices else image_collection
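Based purely on the keys this method reads, a call could look as follows; the values (paths, dates, bounding box) are made-up placeholders and `backend` stands in for whatever object provides load_disk_data.

# Hypothetical invocation; values are illustrative only.
options = {"date_regex": r".*_(\d{4})(\d{2})(\d{2}).*"}
viewing_parameters = {
    "from": "2019-01-01",
    "to": "2019-01-31",
    "left": 4.0, "right": 5.0, "top": 52.0, "bottom": 51.0,
    "srs": "EPSG:4326",
    "bands": [0, 1],
}
image_collection = backend.load_disk_data(
    format="GTiff",
    glob_pattern="hdfs:///data/tiles/*.tif",
    options=options,
    viewing_parameters=viewing_parameters,
)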
Example 3
def evaluate() -> Dict:
    kerberos()
    return ProcessGraphDeserializer.evaluate(
        process_graph, viewingParameters={'version': "0.4.0"})
Example 4
def main(argv: List[str]) -> None:
    logger.debug("argv: {a!r}".format(a=argv))

    if len(argv) < 4:
        print(
            "usage: %s <job specification input file> <results output file> <user log file> [api version]"
            % argv[0],
            file=sys.stderr)
        exit(1)

    job_specification_file, output_file, log_file = argv[1], argv[2], argv[3]
    api_version = argv[4] if len(argv) == 5 else None

    _setup_user_logging(log_file)

    try:
        job_specification = _parse(job_specification_file)
        viewing_parameters = {'version': api_version} if api_version else None
        process_graph = job_specification['process_graph']

        try:
            import custom_processes
        except ImportError:
            logger.info('No custom_processes.py found.')

        with SparkContext.getOrCreate():
            kerberos()
            result = ProcessGraphDeserializer.evaluate(process_graph,
                                                       viewing_parameters)
            logger.info(
                "Evaluated process graph result of type {t}: {r!r}".format(
                    t=type(result), r=result))

            # Write the result in a way that depends on its type.
            if isinstance(result, ImageCollection):
                format_options = job_specification.get('output', {})
                result.download(output_file,
                                bbox="",
                                time="",
                                **format_options)
                logger.info("wrote image collection to %s" % output_file)
            elif isinstance(result, ImageCollectionResult):
                result.imagecollection.download(output_file,
                                                bbox="",
                                                time="",
                                                format=result.format,
                                                **result.options)
                logger.info("wrote image collection to %s" % output_file)
            elif isinstance(result, JSONResult):
                with open(output_file, 'w') as f:
                    json.dump(result.prepare_for_json(), f)
                logger.info("wrote JSON result to %s" % output_file)
            elif isinstance(result, MultipleFilesResult):
                result.reduce(output_file, delete_originals=True)
                logger.info("reduced %d files to %s" %
                            (len(result.files), output_file))
            else:
                with open(output_file, 'w') as f:
                    json.dump(result, f)
                logger.info("wrote JSON result to %s" % output_file)
    except Exception as e:
        logger.exception("error processing batch job")
        user_facing_logger.exception("error processing batch job")
        raise e
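The usage message above implies an invocation like the following; the script and file names are placeholders, and the trailing API version is optional.

# Hypothetical command line for this entry point:
#   python batch_job.py job_specification.json results.json user.log 0.4.0
main(["batch_job.py", "job_specification.json", "results.json", "user.log", "0.4.0"])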
Example 5
    def start_job(self, job_id: str, user_id: str):
        from pyspark import SparkContext

        with JobRegistry() as registry:
            job_info = registry.get_job(job_id, user_id)
            api_version = job_info.get('api_version')

            current_status = job_info['status']
            if current_status in ['queued', 'running']:
                return
            elif current_status != 'created':
                # TODO: is this about restarting a job?
                registry.mark_ongoing(job_id, user_id)
                registry.set_application_id(job_id, user_id, None)
                registry.set_status(job_id, user_id, 'created')

            spec = json.loads(job_info.get('specification'))
            extra_options = spec.get('job_options', {})

            # Per-job resource overrides from "job_options"; defaults otherwise.
            driver_memory = extra_options.get("driver-memory", "22G")
            executor_memory = extra_options.get("executor-memory", "5G")

            kerberos()

            output_dir = self._get_job_output_dir(job_id)
            input_file = output_dir / "in"
            # TODO: how support multiple output files?
            output_file = output_dir / "out"
            log_file = output_dir / "log"

            with input_file.open('w') as f:
                f.write(job_info['specification'])

            conf = SparkContext.getOrCreate().getConf()
            principal, key_tab = conf.get("spark.yarn.principal"), conf.get(
                "spark.yarn.keytab")

            script_location = pkg_resources.resource_filename(
                'openeogeotrellis.deploy', 'submit_batch_job.sh')

            # Positional arguments passed to submit_batch_job.sh.
            args = [
                script_location,
                "OpenEO batch job {j} user {u}".format(j=job_id, u=user_id),
                str(input_file),
                str(output_file),
                str(log_file)
            ]

            if principal is not None and key_tab is not None:
                args.append(principal)
                args.append(key_tab)
            else:
                args.append("no_principal")
                args.append("no_keytab")
            if api_version:
                args.append(api_version)
            else:
                args.append("0.4.0")

            args.append(driver_memory)
            args.append(executor_memory)

            try:
                output_string = subprocess.check_output(
                    args, stderr=subprocess.STDOUT, universal_newlines=True)
            except CalledProcessError as e:
                logger.exception(e)
                logger.error(e.stdout)
                logger.error(e.stderr)
                raise e

            try:
                # note: a job_id is returned as soon as an application ID is found in stderr, not when the job is finished
                logger.info(output_string)
                application_id = self._extract_application_id(output_string)
                print("mapped job_id %s to application ID %s" %
                      (job_id, application_id))

                registry.set_application_id(job_id, user_id, application_id)
            except _BatchJobError as e:
                traceback.print_exc(file=sys.stderr)
                # TODO: why reraise as CalledProcessError?
                raise CalledProcessError(1, str(args), output=output_string)
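Based on the keys read above, a job specification could carry resource overrides like this; the values are illustrative only.

# Hypothetical job specification excerpt with per-job resource overrides.
job_specification = {
    "process_graph": {},          # openEO process graph (omitted here)
    "job_options": {
        "driver-memory": "8G",    # overrides the 22G default above
        "executor-memory": "4G",  # overrides the 5G default above
    },
}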
Example 6
    def load_collection(
            self, collection_id: str,
            viewing_parameters: dict) -> 'GeotrellisTimeSeriesImageCollection':
        logger.info("Creating layer for {c} with viewingParameters {v}".format(
            c=collection_id, v=viewing_parameters))

        # TODO is it necessary to do this kerberos stuff here?
        kerberos()

        # The collection metadata determines which pyramid factory to use:
        # the "_vito"/"data_source" block identifies the backing data source.
        metadata = CollectionMetadata(
            self.get_collection_metadata(collection_id, strip_private=False))
        layer_source_info = metadata.get("_vito", "data_source", default={})
        layer_source_type = layer_source_info.get("type", "Accumulo").lower()
        logger.info("Layer source type: {s!r}".format(s=layer_source_type))

        import geopyspark as gps
        from_date = normalize_date(viewing_parameters.get("from", None))
        to_date = normalize_date(viewing_parameters.get("to", None))

        left = viewing_parameters.get("left", None)
        right = viewing_parameters.get("right", None)
        top = viewing_parameters.get("top", None)
        bottom = viewing_parameters.get("bottom", None)
        srs = viewing_parameters.get("srs", None)
        bands = viewing_parameters.get("bands", None)
        band_indices = [metadata.get_band_index(b)
                        for b in bands] if bands else None
        logger.info("band_indices: {b!r}".format(b=band_indices))
        # TODO: avoid this `still_needs_band_filter` ugliness.
        #       Also see https://github.com/Open-EO/openeo-geopyspark-driver/issues/29
        still_needs_band_filter = False
        pysc = gps.get_spark_context()
        extent = None

        gateway = JavaGateway(
            eager_load=True,
            gateway_parameters=pysc._gateway.gateway_parameters)
        jvm = gateway.jvm

        spatial_bounds_present = left is not None and right is not None and top is not None and bottom is not None

        if spatial_bounds_present:
            extent = jvm.geotrellis.vector.Extent(float(left), float(bottom),
                                                  float(right), float(top))
        elif ConfigParams().require_bounds:
            raise ProcessGraphComplexityException
        else:
            # No bounds given and none required: fall back to a global extent in EPSG:4326.
            srs = "EPSG:4326"
            extent = jvm.geotrellis.vector.Extent(-180.0, -90.0, 180.0, 90.0)

        def accumulo_pyramid():
            pyramidFactory = jvm.org.openeo.geotrellisaccumulo.PyramidFactory(
                "hdp-accumulo-instance",
                ','.join(ConfigParams().zookeepernodes))
            if layer_source_info.get("split", False):
                pyramidFactory.setSplitRanges(True)

            accumulo_layer_name = layer_source_info['data_id']
            nonlocal still_needs_band_filter
            still_needs_band_filter = bool(band_indices)
            return pyramidFactory.pyramid_seq(accumulo_layer_name, extent, srs,
                                              from_date, to_date)

        def s3_pyramid():
            endpoint = layer_source_info['endpoint']
            region = layer_source_info['region']
            bucket_name = layer_source_info['bucket_name']
            nonlocal still_needs_band_filter
            still_needs_band_filter = bool(band_indices)
            return jvm.org.openeo.geotrelliss3.PyramidFactory(endpoint, region, bucket_name) \
                .pyramid_seq(extent, srs, from_date, to_date)

        def s3_jp2_pyramid():
            endpoint = layer_source_info['endpoint']
            region = layer_source_info['region']

            return jvm.org.openeo.geotrelliss3.Jp2PyramidFactory(endpoint, region) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def file_s2_radiometry_pyramid():
            return jvm.org.openeo.geotrellis.file.Sentinel2RadiometryPyramidFactory() \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def file_s2_pyramid():
            oscars_collection_id = layer_source_info['oscars_collection_id']
            oscars_link_titles = metadata.band_names
            root_path = layer_source_info['root_path']

            filtered_link_titles = [
                oscars_link_titles[i] for i in band_indices
            ] if band_indices else oscars_link_titles

            return jvm.org.openeo.geotrellis.file.Sentinel2PyramidFactory(
                oscars_collection_id, filtered_link_titles,
                root_path).pyramid_seq(extent, srs, from_date, to_date)

        def geotiff_pyramid():
            glob_pattern = layer_source_info['glob_pattern']
            date_regex = layer_source_info['date_regex']

            new_pyramid_factory = jvm.org.openeo.geotrellis.geotiff.PyramidFactory.from_disk(
                glob_pattern, date_regex)

            return self._geotiff_pyramid_factories.setdefault(collection_id, new_pyramid_factory) \
                .pyramid_seq(extent, srs, from_date, to_date)

        def file_s1_coherence_pyramid():
            return jvm.org.openeo.geotrellis.file.Sentinel1CoherencePyramidFactory() \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def sentinel_hub_s1_pyramid():
            return jvm.org.openeo.geotrellissentinelhub.S1PyramidFactory(layer_source_info.get('uuid')) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def sentinel_hub_s2_l1c_pyramid():
            return jvm.org.openeo.geotrellissentinelhub.S2L1CPyramidFactory(layer_source_info.get('uuid')) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def sentinel_hub_s2_l2a_pyramid():
            return jvm.org.openeo.geotrellissentinelhub.S2L2APyramidFactory(layer_source_info.get('uuid')) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        def sentinel_hub_l8_pyramid():
            return jvm.org.openeo.geotrellissentinelhub.L8PyramidFactory(layer_source_info.get('uuid')) \
                .pyramid_seq(extent, srs, from_date, to_date, band_indices)

        logger.info("loading pyramid {s}".format(s=layer_source_type))
        if layer_source_type == 's3':
            pyramid = s3_pyramid()
        elif layer_source_type == 's3-jp2':
            pyramid = s3_jp2_pyramid()
        elif layer_source_type == 'file-s2-radiometry':
            pyramid = file_s2_radiometry_pyramid()
        elif layer_source_type == 'file-s2':
            pyramid = file_s2_pyramid()
        elif layer_source_type == 'geotiff':
            pyramid = geotiff_pyramid()
        elif layer_source_type == 'file-s1-coherence':
            pyramid = file_s1_coherence_pyramid()
        elif layer_source_type == 'sentinel-hub-s1':
            pyramid = sentinel_hub_s1_pyramid()
        elif layer_source_type == 'sentinel-hub-s2-l1c':
            pyramid = sentinel_hub_s2_l1c_pyramid()
        elif layer_source_type == 'sentinel-hub-s2-l2a':
            pyramid = sentinel_hub_s2_l2a_pyramid()
        elif layer_source_type == 'sentinel-hub-l8':
            pyramid = sentinel_hub_l8_pyramid()
        else:
            pyramid = accumulo_pyramid()

        temporal_tiled_raster_layer = jvm.geopyspark.geotrellis.TemporalTiledRasterLayer
        option = jvm.scala.Option
        levels = {
            pyramid.apply(index)._1(): TiledRasterLayer(
                LayerType.SPACETIME,
                temporal_tiled_raster_layer(
                    option.apply(pyramid.apply(index)._1()),
                    pyramid.apply(index)._2()))
            for index in range(0, pyramid.size())
        }

        image_collection = GeotrellisTimeSeriesImageCollection(
            pyramid=gps.Pyramid(levels),
            service_registry=self._service_registry,
            metadata=metadata)

        if still_needs_band_filter:
            # TODO: avoid this `still_needs_band_filter` ugliness.
            #       Also see https://github.com/Open-EO/openeo-geopyspark-driver/issues/29
            image_collection = image_collection.band_filter(band_indices)

        return image_collection
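To make the dispatch above concrete, a collection backed by GeoTIFFs on disk might carry metadata along these lines; only the key names come from the code above, the values are illustrative.

# Hypothetical collection metadata excerpt for the 'geotiff' source type.
collection_metadata = {
    "_vito": {
        "data_source": {
            "type": "geotiff",
            "glob_pattern": "/data/MY_COLLECTION/*.tif",
            "date_regex": r".*_(\d{4})(\d{2})(\d{2})\.tif",
        }
    }
}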
Example 7
def main(argv: List[str]) -> None:
    logger.info("argv: {a!r}".format(a=argv))
    logger.info("pid {p}; ppid {pp}; cwd {c}".format(p=os.getpid(),
                                                     pp=os.getppid(),
                                                     c=os.getcwd()))

    if len(argv) < 6:
        print(
            "usage: %s "
            "<job specification input file> <job directory> <results output file name> <user log file name> "
            "<metadata file name> [api version] [dependencies]" % argv[0],
            file=sys.stderr)
        exit(1)

    job_specification_file = argv[1]
    job_dir = Path(argv[2])
    output_file = job_dir / argv[3]
    log_file = job_dir / argv[4]
    metadata_file = job_dir / argv[5]
    api_version = argv[6] if len(argv) >= 7 else None
    dependencies = _deserialize_dependencies(argv[7]) if len(argv) >= 8 else {}
    user_id = argv[8] if len(argv) >= 9 else None

    _create_job_dir(job_dir)

    _setup_user_logging(log_file)

    # Override default temp dir (under CWD). Original default temp dir `/tmp` might be cleaned up unexpectedly.
    temp_dir = Path(os.getcwd()) / "tmp"
    temp_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Using temp dir {t}".format(t=temp_dir))
    os.environ["TMPDIR"] = str(temp_dir)

    try:
        if ConfigParams().is_kube_deploy:
            from openeogeotrellis.utils import s3_client

            bucket = os.environ.get('SWIFT_BUCKET')
            s3_instance = s3_client()

            s3_instance.download_file(bucket,
                                      job_specification_file.strip("/"),
                                      job_specification_file)

        job_specification = _parse(job_specification_file)
        load_custom_processes()

        # Configure Kryo serialization for the objects exchanged with the JVM side.
        conf = (SparkConf().set(
            "spark.serializer", "org.apache.spark.serializer.KryoSerializer"
        ).set(
            key='spark.kryo.registrator',
            value='geopyspark.geotools.kryo.ExpandedKryoRegistrator'
        ).set(
            "spark.kryo.classesToRegister",
            "org.openeo.geotrellisaccumulo.SerializableConfiguration,ar.com.hjg.pngj.ImageInfo,ar.com.hjg.pngj.ImageLineInt,geotrellis.raster.RasterRegion$GridBoundsRasterRegion"
        ))

        with SparkContext(conf=conf) as sc:
            principal = sc.getConf().get("spark.yarn.principal")
            key_tab = sc.getConf().get("spark.yarn.keytab")

            kerberos(principal, key_tab)

            def run_driver():
                run_job(job_specification=job_specification,
                        output_file=output_file,
                        metadata_file=metadata_file,
                        api_version=api_version,
                        job_dir=job_dir,
                        dependencies=dependencies,
                        user_id=user_id)

            if sc.getConf().get('spark.python.profile',
                                'false').lower() == 'true':
                # Including the driver in the profiling: a bit hacky solution but spark profiler api does not allow passing args&kwargs
                driver_profile = BasicProfiler(sc)
                driver_profile.profile(run_driver)
                # running the driver code and adding driver's profiling results as "RDD==-1"
                sc.profiler_collector.add_profiler(-1, driver_profile)
                # collect profiles into a zip file
                profile_dumps_dir = job_dir / 'profile_dumps'
                sc.dump_profiles(profile_dumps_dir)

                profile_zip = shutil.make_archive(
                    base_name=str(profile_dumps_dir),
                    format='gztar',
                    root_dir=profile_dumps_dir)
                _add_permissions(Path(profile_zip), stat.S_IWGRP)

                shutil.rmtree(
                    profile_dumps_dir,
                    onerror=lambda func, path, exc_info: logger.warning(
                        f"could not recursively delete {profile_dumps_dir}: {func} {path} failed",
                        exc_info=exc_info))

                logger.info("Saved profiling info to: " + profile_zip)
            else:
                run_driver()

    except Exception as e:
        logger.exception("error processing batch job")
        user_facing_logger.exception("error processing batch job")
        if "Container killed on request. Exit code is 143" in str(e):
            user_facing_logger.error(
                "Your batch job failed because workers used too much Python memory. The same task was attempted multiple times. Consider increasing executor-memoryOverhead or contact the developers to investigate."
            )
        raise e
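The profiling branch above only runs when spark.python.profile evaluates to true (the code falls back to 'false'). A minimal sketch of enabling it on the Spark configuration, under the assumption that this fits the deployment:

from pyspark import SparkConf

# Enable the Python profiler so the profiling branch above is taken.
conf = SparkConf().set("spark.python.profile", "true")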