Example #1
 def _request(self, url, info_message=None, exception_message=None):
     try:
         if info_message:
             logger.info(info_message)
         response = requests.post(url, json=self.query_params)
         response.raise_for_status()
     except (requests.HTTPError, urllib_HTTPError) as err:
         # check if error is identified as auth_error in provider conf
         auth_errors = getattr(self.config, "auth_error_code", [None])
         if not isinstance(auth_errors, list):
             auth_errors = [auth_errors]
         if err.response.status_code in auth_errors:
             raise AuthenticationError(
                 "HTTP Error {} returned:\n{}\nPlease check your credentials for {}".format(
                     err.response.status_code,
                     err.response.text.strip(),
                     self.provider,
                 )
             )
         if exception_message:
             logger.exception(exception_message)
         else:
             logger.exception(
                 "Skipping error while requesting: %s (provider:%s, plugin:%s):",
                 url,
                 self.provider,
                 self.__class__.__name__,
             )
         raise RequestError(str(err))
     return response
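
The pattern above maps provider-configured HTTP status codes onto a dedicated
AuthenticationError. A minimal, self-contained sketch of that idea, assuming
plain requests (the helper name and exception class are illustrative, not
eodag's API):

import requests

class AuthenticationError(Exception):
    """Raised when the provider rejects the configured credentials."""

def post_or_raise(url, payload, auth_error_codes=(401, 403)):
    # POST the payload and translate configured auth-error statuses
    # into AuthenticationError; re-raise anything else untouched.
    response = requests.post(url, json=payload)
    try:
        response.raise_for_status()
    except requests.HTTPError as err:
        if err.response.status_code in auth_error_codes:
            raise AuthenticationError(
                "HTTP Error %s returned, please check your credentials"
                % err.response.status_code
            )
        raise
    return response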
Example #2
    def authenticate(self):
        """
        Make the authentication request
        """
        self.validate_config_credentials()
        response = self.session.post(
            self.TOKEN_URL_TEMPLATE.format(
                auth_base_uri=self.config.auth_base_uri.rstrip("/"),
                realm=self.config.realm,
            ),
            data={
                "client_id": self.config.client_id,
                "client_secret": self.config.client_secret,
                "username": self.config.credentials["username"],
                "password": self.config.credentials["password"],
                "grant_type": self.GRANT_TYPE,
            },
        )
        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            # check if error is identified as auth_error in provider conf
            auth_errors = getattr(self.config, "auth_error_code", [None])
            if not isinstance(auth_errors, list):
                auth_errors = [auth_errors]
            if e.response.status_code in auth_errors:
                raise AuthenticationError(
                    "HTTP Error %s returned, %s\nPlease check your credentials for %s"
                    % (e.response.status_code, e.response.text.strip(),
                       self.provider))
            # other error
            else:
                import traceback as tb

                raise AuthenticationError(
                    "Something went wrong while trying to get access token:\n{}"
                    .format(tb.format_exc()))
        return CodeAuthorizedAuth(
            response.json()["access_token"],
            self.config.token_provision,
            key=getattr(self.config, "token_qs_key", None),
        )
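
Example #2 implements the OAuth2 resource owner password credentials flow
against a Keycloak-style token endpoint. A minimal sketch of that token
request on its own, assuming plain requests (the endpoint and helper name
are illustrative):

import requests

def fetch_token(token_url, client_id, client_secret, username, password):
    # Request an access token using the "password" grant type.
    response = requests.post(
        token_url,
        data={
            "client_id": client_id,
            "client_secret": client_secret,
            "username": username,
            "password": password,
            "grant_type": "password",
        },
    )
    response.raise_for_status()
    return response.json()["access_token"]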
Example #3
    def get_authenticated_objects(self, bucket_name, prefix, auth_dict):
        """Get boto3 authenticated objects for the given bucket using
        the most suitable auth strategy

        :param bucket_name: bucket containing objects
        :type bucket_name: str
        :param prefix: prefix used to filter objects on auth try
                       (not used to filter returned objects)
        :type prefix: str
        :param auth_dict: dictionary containing authentication keys
        :type auth_dict: dict
        :return: boto3 authenticated objects
        :rtype: :class:`~boto3.resources.collection.s3.Bucket.objectsCollection`
        """
        auth_methods = [
            self._get_authenticated_objects_unsigned,
            self._get_authenticated_objects_from_auth_profile,
            self._get_authenticated_objects_from_auth_keys,
            self._get_authenticated_objects_from_env,
        ]
        # skip _get_authenticated_objects_from_env if credentials were filled in eodag conf
        if auth_dict:
            del auth_methods[-1]

        for try_auth_method in auth_methods:
            try:
                s3_objects = try_auth_method(bucket_name, prefix, auth_dict)
                if s3_objects:
                    logger.debug("Auth using %s succeeded",
                                 try_auth_method.__name__)
                    return s3_objects
            except ClientError as e:
                if e.response.get("Error", {}).get("Code", "") not in [
                        "AccessDenied",
                        "InvalidAccessKeyId",
                        "SignatureDoesNotMatch",
                ]:
                    raise
            except ProfileNotFound:
                pass
            logger.debug("Auth using %s failed", try_auth_method.__name__)

        raise AuthenticationError(
            "Unable do authenticate on s3://%s using any available credendials configuration"
            % bucket_name)
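
The loop above tries authentication strategies in order and falls through on
recoverable errors. The same control flow, reduced to a generic sketch
(exception types and names here are stand-ins, not boto3's):

def first_successful(strategies, *args, recoverable=(PermissionError,)):
    # Try each strategy in order; swallow recoverable failures and
    # raise only once every strategy has been exhausted.
    errors = []
    for strategy in strategies:
        try:
            result = strategy(*args)
            if result:
                return result
        except recoverable as err:
            errors.append(err)
    raise RuntimeError("All strategies failed: %s" % errors)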
Example #4
    def authenticate(self):
        """Authenticate"""
        state = self.compute_state()
        authentication_response = self.authenticate_user(state)
        exchange_url = authentication_response.url
        if self.config.user_consent_needed:
            user_consent_response = self.grant_user_consent(
                authentication_response)
            exchange_url = user_consent_response.url
        try:
            token = self.exchange_code_for_token(exchange_url, state)
        except Exception:
            import traceback as tb

            raise AuthenticationError(
                "Something went wrong while trying to get authorization token:\n{}"
                .format(tb.format_exc()))
        return CodeAuthorizedAuth(
            token,
            self.config.token_provision,
            key=getattr(self.config, "token_qs_key", None),
        )
Example #5
 def exchange_code_for_token(self, authorized_url, state):
     """Get exchange code for token"""
     qs = parse_qs(urlparse(authorized_url).query)
     if qs["state"][0] != state:
         raise AuthenticationError(
             "The state received in the authorized url does not match initially computed state"
         )
     code = qs["code"][0]
     token_exchange_data = {
         "redirect_uri": self.config.redirect_uri,
         "client_id": self.config.client_id,
         "code": code,
         "state": state,
     }
     # If necessary, change the keys of the form data that will be passed to the token exchange POST request
     custom_token_exchange_params = getattr(self.config,
                                            "token_exchange_params", {})
     if custom_token_exchange_params:
         token_exchange_data[custom_token_exchange_params[
             "redirect_uri"]] = token_exchange_data.pop("redirect_uri")
         token_exchange_data[custom_token_exchange_params[
             "client_id"]] = token_exchange_data.pop("client_id")
     # If the client_secret is known, the token exchange request must be authenticated with HTTP Basic auth, using
     # the client_id and client_secret as username and password respectively
     if getattr(self.config, "client_secret", None):
         token_exchange_data.update({
             "auth": (self.config.client_id, self.config.client_secret),
             "grant_type": "authorization_code",
             "client_secret": self.config.client_secret,
         })
     post_request_kwargs = {
         self.config.token_exchange_post_data_method: token_exchange_data
     }
     r = self.session.post(self.config.token_uri, **post_request_kwargs)
     return r.json()[self.config.token_key]
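
The state comparison at the top of Example #5 is the standard CSRF guard in
the authorization-code flow. A self-contained sketch of extracting and
validating the redirect query parameters with only the standard library
(function name and URL are illustrative):

from urllib.parse import parse_qs, urlparse

def extract_code(authorized_url, expected_state):
    # Parse the redirect URL and check that the returned state matches
    # the one generated before the authorization request was sent.
    qs = parse_qs(urlparse(authorized_url).query)
    if qs["state"][0] != expected_state:
        raise ValueError("state mismatch: possible CSRF")
    return qs["code"][0]

# extract_code("https://app/cb?code=abc&state=xyz", "xyz") returns "abc"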
Example #6
File: http.py Project: apparell/eodag
    def download(
        self,
        product,
        auth=None,
        progress_callback=None,
        wait=DEFAULT_DOWNLOAD_WAIT,
        timeout=DEFAULT_DOWNLOAD_TIMEOUT,
        **kwargs
    ):
        """Download a product using HTTP protocol.

        The downloaded product is assumed to be a Zip file. If it is not,
        the user is warned, the file is renamed to remove the zip extension,
        and no further processing is done (no extraction)
        """
        fs_path, record_filename = self._prepare_download(product, **kwargs)
        if not fs_path or not record_filename:
            return fs_path

        # progress bar init
        if progress_callback is None:
            progress_callback = get_progress_callback()
        progress_callback.desc = product.properties.get("id", "")
        progress_callback.position = 1

        # download assets if they exist, instead of using remote_location
        try:
            return self._download_assets(
                product,
                fs_path.replace(".zip", ""),
                record_filename,
                auth,
                progress_callback,
                **kwargs
            )
        except NotAvailableError:
            pass

        url = product.remote_location

        # order product if it is offline
        ordered_message = ""
        if (
            "orderLink" in product.properties
            and "storageStatus" in product.properties
            and product.properties["storageStatus"] == OFFLINE_STATUS
        ):
            order_method = getattr(self.config, "order_method", "GET")
            with requests.request(
                method=order_method,
                url=product.properties["orderLink"],
                auth=auth,
                headers=getattr(self.config, "order_headers", {}),
            ) as response:
                try:
                    response.raise_for_status()
                    ordered_message = response.text
                    logger.debug(ordered_message)
                except HTTPError as e:
                    logger.warning(
                        "%s could not be ordered, request returned %s",
                        product.properties["title"],
                        e,
                    )

        # initiate retry loop
        start_time = datetime.now()
        stop_time = datetime.now() + timedelta(minutes=timeout)
        product.next_try = start_time
        retry_count = 0
        not_available_info = "The product could not be downloaded"
        # another output for notebooks
        nb_info = NotebookWidgets()

        while "Loop until products download succeeds or timeout is reached":

            if datetime.now() >= product.next_try:
                product.next_try += timedelta(minutes=wait)
                try:
                    params = kwargs.pop("dl_url_params", None) or getattr(
                        self.config, "dl_url_params", {}
                    )
                    with requests.get(
                        url,
                        stream=True,
                        auth=auth,
                        params=params,
                    ) as stream:
                        try:
                            stream.raise_for_status()
                        except HTTPError as e:
                            # check if error is identified as auth_error in provider conf
                            auth_errors = getattr(
                                self.config, "auth_error_code", [None]
                            )
                            if not isinstance(auth_errors, list):
                                auth_errors = [auth_errors]
                            if e.response.status_code in auth_errors:
                                raise AuthenticationError(
                                    "HTTP Error %s returned, %s\nPlease check your credentials for %s"
                                    % (
                                        e.response.status_code,
                                        e.response.text.strip(),
                                        self.provider,
                                    )
                                )
                            # product not available
                            elif (
                                product.properties.get("storageStatus", ONLINE_STATUS)
                                != ONLINE_STATUS
                            ):
                                msg = (
                                    ordered_message
                                    if ordered_message and not e.response.text
                                    else e.response.text
                                )
                                raise NotAvailableError(
                                    "%s(initially %s) requested, returned: %s"
                                    % (
                                        product.properties["title"],
                                        product.properties["storageStatus"],
                                        msg,
                                    )
                                )
                            else:
                                import traceback as tb

                                logger.error(
                                    "Error while getting resource :\n%s",
                                    tb.format_exc(),
                                )
                        else:
                            stream_size = int(stream.headers.get("content-length", 0))
                            if (
                                stream_size == 0
                                and "storageStatus" in product.properties
                                and product.properties["storageStatus"] != ONLINE_STATUS
                            ):
                                raise NotAvailableError(
                                    "%s(initially %s) ordered, got: %s"
                                    % (
                                        product.properties["title"],
                                        product.properties["storageStatus"],
                                        stream.reason,
                                    )
                                )
                            progress_callback.max_size = stream_size
                            progress_callback.reset()
                            with open(fs_path, "wb") as fhandle:
                                for chunk in stream.iter_content(chunk_size=64 * 1024):
                                    if chunk:
                                        fhandle.write(chunk)
                                        progress_callback(len(chunk), stream_size)

                            with open(record_filename, "w") as fh:
                                fh.write(url)
                            logger.debug("Download recorded in %s", record_filename)

                            # Check that the downloaded file is really a zip file
                            if not zipfile.is_zipfile(fs_path):
                                logger.warning(
                                    "Downloaded product is not a Zip File. Please check its file type before using it"
                                )
                                new_fs_path = fs_path[: fs_path.index(".zip")]
                                shutil.move(fs_path, new_fs_path)
                                return new_fs_path
                            return self._finalize(fs_path, **kwargs)

                except NotAvailableError as e:
                    if not getattr(self.config, "order_enabled", False):
                        raise NotAvailableError(
                            "Product is not available for download and order is not supported for %s, %s"
                            % (self.provider, e)
                        )
                    not_available_info = e

            if datetime.now() < product.next_try and datetime.now() < stop_time:
                wait_seconds = (product.next_try - datetime.now()).seconds
                retry_count += 1
                retry_info = (
                    "[Retry #%s] Waiting %ss until next download try (retry every %s' for %s')"
                    % (retry_count, wait_seconds, wait, timeout)
                )
                logger.debug(not_available_info)
                # Retry-After info from Response header
                retry_server_info = stream.headers.get("Retry-After", "")
                if retry_server_info:
                    logger.debug(
                        "[%s response] Retry-After: %s"
                        % (self.provider, retry_server_info)
                    )
                logger.info(retry_info)
                nb_info.display_html(retry_info)
                sleep(wait_seconds + 1)
            elif datetime.now() >= stop_time and timeout > 0:
                if "storageStatus" not in product.properties:
                    product.properties["storageStatus"] = "N/A status"
                logger.info(not_available_info)
                raise NotAvailableError(
                    "%s is not available (%s) and could not be downloaded, timeout reached"
                    % (product.properties["title"], product.properties["storageStatus"])
                )
            elif datetime.now() >= stop_time:
                raise NotAvailableError(not_available_info)
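
Stripped of the HTTP and ordering logic, Example #6's retry loop is a
wait/timeout schedule driven by datetime arithmetic. A minimal sketch of
just that schedule (names are illustrative):

from datetime import datetime, timedelta
from time import sleep

def retry_until(action, wait_minutes=2, timeout_minutes=20):
    # Re-run `action` every `wait_minutes` until it returns a truthy
    # value or `timeout_minutes` have elapsed.
    stop_time = datetime.now() + timedelta(minutes=timeout_minutes)
    next_try = datetime.now()
    while datetime.now() < stop_time:
        if datetime.now() >= next_try:
            result = action()
            if result:
                return result
            next_try = datetime.now() + timedelta(minutes=wait_minutes)
        sleep(1)
    raise TimeoutError("action did not succeed within the timeout")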
Example #7
File: http.py Project: apparell/eodag
    def _download_assets(
        self,
        product,
        fs_dir_path,
        record_filename,
        auth=None,
        progress_callback=None,
        **kwargs
    ):
        """Download product assets if they exist"""
        assets_urls = [
            a["href"] for a in getattr(product, "assets", {}).values() if "href" in a
        ]

        if not assets_urls:
            raise NotAvailableError("No assets available for %s" % product)

        # remove existing incomplete file
        if os.path.isfile(fs_dir_path):
            os.remove(fs_dir_path)
        # create product dest dir
        if not os.path.isdir(fs_dir_path):
            os.makedirs(fs_dir_path)

        # product conf overrides provider conf for "flatten_top_dirs"
        product_conf = getattr(self.config, "products", {}).get(
            product.product_type, {}
        )
        flatten_top_dirs = product_conf.get(
            "flatten_top_dirs", getattr(self.config, "flatten_top_dirs", False)
        )

        total_size = sum(
            [
                int(
                    requests.head(asset_url, auth=auth).headers.get("Content-length", 0)
                )
                for asset_url in assets_urls
            ]
        )
        progress_callback.max_size = total_size
        progress_callback.reset()
        error_messages = set()

        for asset_url in assets_urls:

            params = kwargs.pop("dl_url_params", None) or getattr(
                self.config, "dl_url_params", {}
            )
            with requests.get(
                asset_url,
                stream=True,
                auth=auth,
                params=params,
            ) as stream:
                try:
                    stream.raise_for_status()
                except HTTPError as e:
                    # check if error is identified as auth_error in provider conf
                    auth_errors = getattr(self.config, "auth_error_code", [None])
                    if not isinstance(auth_errors, list):
                        auth_errors = [auth_errors]
                    if e.response.status_code in auth_errors:
                        raise AuthenticationError(
                            "HTTP Error %s returned, %s\nPlease check your credentials for %s"
                            % (
                                e.response.status_code,
                                e.response.text.strip(),
                                self.provider,
                            )
                        )
                    else:
                        logger.warning("Unexpected error: %s" % e)
                        logger.warning("Skipping %s" % asset_url)
                    error_messages.add(str(e))
                else:
                    asset_rel_path = (
                        asset_url.replace(product.location, "")
                        .replace("https://", "")
                        .replace("http://", "")
                    )
                    asset_abs_path = os.path.join(fs_dir_path, asset_rel_path)
                    asset_abs_path_dir = os.path.dirname(asset_abs_path)
                    if not os.path.isdir(asset_abs_path_dir):
                        os.makedirs(asset_abs_path_dir)

                    if not os.path.isfile(asset_abs_path):
                        with open(asset_abs_path, "wb") as fhandle:
                            for chunk in stream.iter_content(chunk_size=64 * 1024):
                                if chunk:
                                    fhandle.write(chunk)
                                    progress_callback(len(chunk))

        # could not download any file
        if len(os.listdir(fs_dir_path)) == 0:
            raise HTTPError(", ".join(error_messages))

        # flatten directory structure
        if flatten_top_dirs:
            tmp_product_local_path = "%s-tmp" % fs_dir_path
            for d, dirs, files in os.walk(fs_dir_path):
                if len(files) != 0:
                    shutil.copytree(d, tmp_product_local_path)
                    shutil.rmtree(fs_dir_path)
                    os.rename(tmp_product_local_path, fs_dir_path)
                    break

        # save hash/record file
        with open(record_filename, "w") as fh:
            fh.write(product.remote_location)
        logger.debug("Download recorded in %s", record_filename)

        return fs_dir_path
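
The chunked writes in Examples #6 and #7 are the usual requests streaming
idiom; isolated, it looks like this (a sketch, with illustrative names):

import requests

def stream_to_file(url, path, chunk_size=64 * 1024):
    # Stream the response body to disk without holding it in memory.
    with requests.get(url, stream=True) as stream:
        stream.raise_for_status()
        with open(path, "wb") as fhandle:
            for chunk in stream.iter_content(chunk_size=chunk_size):
                if chunk:  # skip keep-alive chunks
                    fhandle.write(chunk)
    return path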
Example #8
File: usgs.py Project: ahuarte47/eodag
    def query(self,
              product_type=None,
              items_per_page=None,
              page=None,
              count=True,
              **kwargs):
        """Search for data on USGS catalogues

        .. versionchanged:: 2.2.0

            * Based on the usgs library v0.3.0, which now uses the M2M API.
              The library is used for both search & download

        .. versionchanged:: 1.0

            * ``product_type`` is no longer mandatory
        """
        product_type = kwargs.get("productType")
        if product_type is None:
            return [], 0
        try:
            api.login(
                self.config.credentials["username"],
                self.config.credentials["password"],
                save=True,
            )
        except USGSError:
            raise AuthenticationError(
                "Please check your USGS credentials.") from None

        product_type_def_params = self.config.products.get(
            product_type, self.config.products[GENERIC_PRODUCT_TYPE])
        usgs_dataset = format_dict_items(product_type_def_params,
                                         **kwargs)["dataset"]
        start_date = kwargs.pop("startTimeFromAscendingNode", None)
        end_date = kwargs.pop("completionTimeFromAscendingNode", None)
        geom = kwargs.pop("geometry", None)
        footprint = {}
        if hasattr(geom, "bounds"):
            (
                footprint["lonmin"],
                footprint["latmin"],
                footprint["lonmax"],
                footprint["latmax"],
            ) = geom.bounds
        else:
            footprint = geom

        final = []
        if footprint and len(footprint.keys()) == 4:  # a rectangle (or bbox)
            lower_left = {
                "longitude": footprint["lonmin"],
                "latitude": footprint["latmin"],
            }
            upper_right = {
                "longitude": footprint["lonmax"],
                "latitude": footprint["latmax"],
            }
        else:
            lower_left, upper_right = None, None
        try:
            results = api.scene_search(
                usgs_dataset,
                start_date=start_date,
                end_date=end_date,
                ll=lower_left,
                ur=upper_right,
                max_results=items_per_page,
                starting_number=(1 + (page - 1) * items_per_page),
            )

            # Same method as in base.py, Search.__init__()
            # Prepare the metadata mapping
            # Do a shallow copy, the structure is flat enough for this to be sufficient
            metas = DEFAULT_METADATA_MAPPING.copy()
            # Update the defaults with the mapping value. This will add any new key
            # added by the provider mapping that is not in the default metadata.
            # A deepcopy is done to prevent self.config.metadata_mapping from being modified when metas[metadata]
            # is a list and is modified
            metas.update(copy.deepcopy(self.config.metadata_mapping))
            metas = mtd_cfg_as_jsonpath(metas)

            for result in results["data"]["results"]:

                result["productType"] = usgs_dataset

                product_properties = properties_from_json(result, metas)

                final.append(
                    EOProduct(
                        productType=product_type,
                        provider=self.provider,
                        properties=product_properties,
                        geometry=footprint,
                    ))
        except USGSError as e:
            logger.warning(
                "Product type %s does not exist on USGS EE catalog",
                usgs_dataset,
            )
            logger.warning("Skipping error: %s", e)
        api.logout()

        if final:
            # parse total_results
            path_parsed = parse(
                self.config.pagination["total_items_nb_key_path"])
            total_results = path_parsed.find(results["data"])[0].value
        else:
            total_results = 0

        return final, total_results
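
The starting_number passed to api.scene_search converts a 1-based page index
into a 1-based item offset. A quick sketch of the arithmetic:

def starting_number(page, items_per_page):
    # Page 1 starts at item 1, page 2 at items_per_page + 1, and so on.
    return 1 + (page - 1) * items_per_page

assert starting_number(1, 20) == 1
assert starting_number(3, 20) == 41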
Example #9
File: usgs.py Project: ahuarte47/eodag
    def download(self, product, auth=None, progress_callback=None, **kwargs):
        """Download data from USGS catalogues"""

        fs_path, record_filename = self._prepare_download(
            product, outputs_extension=".tar.gz", **kwargs)
        if not fs_path or not record_filename:
            return fs_path

        # progress bar init
        if progress_callback is None:
            progress_callback = get_progress_callback()
        progress_callback.desc = product.properties.get("id", "")
        progress_callback.position = 1

        try:
            api.login(
                self.config.credentials["username"],
                self.config.credentials["password"],
                save=True,
            )
        except USGSError:
            raise AuthenticationError(
                "Please check your USGS credentials.") from None

        download_options = api.download_options(
            product.properties["productType"], product.properties["id"])

        try:
            product_ids = [
                p["id"] for p in download_options["data"]
                if p["downloadSystem"] == "dds"
            ]
        except KeyError as e:
            raise NotAvailableError("%s not found in %s's products" %
                                    (e, product.properties["id"]))

        if not product_ids:
            raise NotAvailableError("No USGS products found for %s" %
                                    product.properties["id"])

        req_urls = []
        for product_id in product_ids:
            download_request = api.download_request(
                product.properties["productType"], product.properties["id"],
                product_id)
            try:
                req_urls.extend([
                    x["url"]
                    for x in download_request["data"]["preparingDownloads"]
                ])
            except KeyError as e:
                raise NotAvailableError("%s not found in %s download_request" %
                                        (e, product.properties["id"]))

        if len(req_urls) > 1:
            logger.warning(
                "%s usgs products found for %s. Only first will be downloaded"
                % (len(req_urls), product.properties["id"]))
        elif not req_urls:
            raise NotAvailableError("No usgs request url was found for %s" %
                                    product.properties["id"])

        req_url = req_urls[0]
        progress_callback.reset()
        with requests.get(
                req_url,
                stream=True,
        ) as stream:
            try:
                stream.raise_for_status()
            except HTTPError:
                import traceback as tb

                logger.error(
                    "Error while getting resource :\n%s",
                    tb.format_exc(),
                )
            else:
                stream_size = int(stream.headers.get("content-length", 0))
                progress_callback.max_size = stream_size
                progress_callback.reset()
                with open(fs_path, "wb") as fhandle:
                    for chunk in stream.iter_content(chunk_size=64 * 1024):
                        if chunk:
                            fhandle.write(chunk)
                            progress_callback(len(chunk), stream_size)

        with open(record_filename, "w") as fh:
            fh.write(product.properties["downloadLink"])
        logger.debug("Download recorded in %s", record_filename)

        api.logout()

        # Check that the downloaded file is really a tar file
        if not tarfile.is_tarfile(fs_path):
            logger.warning(
                "Downloaded product is not a tar File. Please check its file type before using it"
            )
            new_fs_path = fs_path[:fs_path.index(".tar.gz")]
            shutil.move(fs_path, new_fs_path)
            return new_fs_path
        return self._finalize(fs_path, outputs_extension=".tar.gz", **kwargs)
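
The post-download check mirrors Example #6's zip handling, but for tarballs.
A minimal sketch of validating the archive and stripping the extension when
the check fails (helper name is illustrative):

import shutil
import tarfile

def ensure_tar(fs_path):
    # Keep the .tar.gz suffix only if the file really is a tar archive.
    if tarfile.is_tarfile(fs_path):
        return fs_path
    new_fs_path = fs_path[:fs_path.index(".tar.gz")]
    shutil.move(fs_path, new_fs_path)
    return new_fs_path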
Example #10
File: s3rest.py Project: saiplanner/eodag
    def download(self, product, auth=None, progress_callback=None, **kwargs):
        """Download method for S3 REST API.

        :param product: The EO product to download
        :type product: :class:`~eodag.api.product.EOProduct`
        :param auth: (optional) The configuration of a plugin of type Authentication
        :type auth: :class:`~eodag.config.PluginConfig`
        :param progress_callback: (optional) A method or a callable object
                                  which takes a current size and a maximum
                                  size as inputs and handle progress bar
                                  creation and update to give the user a
                                  feedback on the download progress
        :type progress_callback: :class:`~eodag.utils.ProgressCallback` or None
        :return: The absolute path to the downloaded product in the local filesystem
        :rtype: str
        """
        # get bucket urls
        bucket_name, prefix = self.get_bucket_name_and_prefix(product)

        if (bucket_name is None and "storageStatus" in product.properties
                and product.properties["storageStatus"] == OFFLINE_STATUS):
            raise NotAvailableError(
                "%s is not available for download on %s (status = %s)" % (
                    product.properties["title"],
                    self.provider,
                    product.properties["storageStatus"],
                ))

        bucket_url = urljoin(
            product.downloader.config.base_uri.strip("/") + "/", bucket_name)
        nodes_list_url = bucket_url + "?prefix=" + prefix.strip("/")

        # get nodes/files list contained in the bucket
        logger.debug("Retrieving product content from %s", nodes_list_url)
        bucket_contents = requests.get(nodes_list_url, auth=auth)
        try:
            bucket_contents.raise_for_status()
        except requests.HTTPError as err:
            # check if error is identified as auth_error in provider conf
            auth_errors = getattr(self.config, "auth_error_code", [None])
            if not isinstance(auth_errors, list):
                auth_errors = [auth_errors]
            if err.response.status_code in auth_errors:
                raise AuthenticationError(
                    "HTTP Error %s returned, %s\nPlease check your credentials for %s"
                    % (
                        err.response.status_code,
                        err.response.text.strip(),
                        self.provider,
                    ))
            # other error
            else:
                logger.exception(
                    "Could not get content from %s (provider:%s, plugin:%s)\n%s",
                    nodes_list_url,
                    self.provider,
                    self.__class__.__name__,
                    bucket_contents.text,
                )
                raise RequestError(str(err))
        try:
            xmldoc = minidom.parseString(bucket_contents.text)
        except ExpatError as err:
            logger.exception("Could not parse xml data from %s",
                             bucket_contents)
            raise DownloadError(str(err))
        nodes_xml_list = xmldoc.getElementsByTagName("Contents")

        if len(nodes_xml_list) == 0:
            logger.warning("Could not load any content from %s",
                           nodes_list_url)
        elif len(nodes_xml_list) == 1:
            # single file download
            product.remote_location = urljoin(
                bucket_url.strip("/") + "/", prefix.strip("/"))
            return HTTPDownload(self.provider, self.config).download(
                product=product,
                auth=auth,
                progress_callback=progress_callback,
                **kwargs)

        # destination product path
        outputs_prefix = kwargs.pop("ouputs_prefix",
                                    None) or self.config.outputs_prefix
        abs_outputs_prefix = os.path.abspath(outputs_prefix)
        product_local_path = os.path.join(abs_outputs_prefix,
                                          prefix.split("/")[-1])

        # .downloaded cache record directory
        download_records_dir = os.path.join(abs_outputs_prefix, ".downloaded")
        try:
            os.makedirs(download_records_dir)
        except OSError as exc:
            import errno

            if exc.errno != errno.EEXIST:  # Skip error if dir exists
                import traceback as tb

                logger.warning("Unable to create records directory. Got:\n%s",
                               tb.format_exc())
        # check if product has already been downloaded
        url_hash = hashlib.md5(
            product.remote_location.encode("utf-8")).hexdigest()
        record_filename = os.path.join(download_records_dir, url_hash)
        if os.path.isfile(record_filename) and os.path.exists(
                product_local_path):
            return product_local_path
        # Remove the record file if product_local_path is absent (e.g. it was deleted while record wasn't)
        elif os.path.isfile(record_filename):
            logger.debug("Record file found (%s) but not the actual file",
                         record_filename)
            logger.debug("Removing record file : %s", record_filename)
            os.remove(record_filename)

        # total size for progress_callback
        total_size = sum([
            int(node.firstChild.nodeValue)
            for node in xmldoc.getElementsByTagName("Size")
        ])

        # download each node key
        for node_xml in nodes_xml_list:
            node_key = node_xml.getElementsByTagName(
                "Key")[0].firstChild.nodeValue
            # As "Key", "Size" and "ETag" (md5 hash) can also be retrieved from node_xml
            node_url = urljoin(
                bucket_url.strip("/") + "/", node_key.strip("/"))
            # output file location
            local_filename = os.path.join(self.config.outputs_prefix,
                                          "/".join(node_key.split("/")[6:]))
            local_filename_dir = os.path.dirname(
                os.path.realpath(local_filename))
            if not os.path.isdir(local_filename_dir):
                os.makedirs(local_filename_dir)

            with requests.get(node_url, stream=True, auth=auth) as stream:
                try:
                    stream.raise_for_status()
                except HTTPError:
                    import traceback as tb

                    logger.error("Error while getting resource :\n%s",
                                 tb.format_exc())
                else:
                    with open(local_filename, "wb") as fhandle:
                        for chunk in stream.iter_content(chunk_size=64 * 1024):
                            if chunk:
                                fhandle.write(chunk)
                                progress_callback(len(chunk), total_size)

            # TODO: check md5 hash ?

        with open(record_filename, "w") as fh:
            fh.write(product.remote_location)
        logger.debug("Download recorded in %s", record_filename)

        return product_local_path
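
Parsing the S3 REST bucket listing only relies on the Contents/Key/Size
elements. A self-contained sketch against a canned XML payload (the sample
document is illustrative):

from xml.dom import minidom

sample = """<ListBucketResult>
  <Contents><Key>a/b/file1</Key><Size>10</Size></Contents>
  <Contents><Key>a/b/file2</Key><Size>32</Size></Contents>
</ListBucketResult>"""

xmldoc = minidom.parseString(sample)
keys = [
    node.getElementsByTagName("Key")[0].firstChild.nodeValue
    for node in xmldoc.getElementsByTagName("Contents")
]
total_size = sum(
    int(node.firstChild.nodeValue)
    for node in xmldoc.getElementsByTagName("Size")
)
assert keys == ["a/b/file1", "a/b/file2"] and total_size == 42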
Example #11
File: usgs.py Project: saiplanner/eodag
    def query(self, product_type=None, **kwargs):
        """Search for data on USGS catalogues

        .. versionchanged:: 1.0

            * ``product_type`` is no longer mandatory
        """
        product_type = kwargs.get("productType")
        if product_type is None:
            return [], 0
        try:
            api.login(
                self.config.credentials["username"],
                self.config.credentials["password"],
                save=True,
            )
        except USGSError:
            raise AuthenticationError(
                "Please check your USGS credentials.") from None
        usgs_dataset = self.config.products[product_type]["dataset"]
        usgs_catalog_node = self.config.products[product_type]["catalog_node"]
        start_date = kwargs.pop("startTimeFromAscendingNode", None)
        end_date = kwargs.pop("completionTimeFromAscendingNode", None)
        geom = kwargs.pop("geometry", None)
        footprint = {}
        if hasattr(geom, "bounds"):
            (
                footprint["lonmin"],
                footprint["latmin"],
                footprint["lonmax"],
                footprint["latmax"],
            ) = geom.bounds
        else:
            footprint = geom

        # Configuration to generate the download url of search results
        result_summary_pattern = re.compile(
            r"^ID: .+, Acquisition Date: .+, Path: (?P<path>\d+), Row: (?P<row>\d+)$"  # noqa
        )
        # See https://pyformat.info/, on section "Padding and aligning strings" to
        # understand {path:0>3} and {row:0>3}.
        # It roughly means: 'if the string that will be passed as "path" has length < 3,
        # prepend as much "0"s as needed to reach length 3' and same for "row"
        dl_url_pattern = "{base_url}/L8/{path:0>3}/{row:0>3}/{entity}.tar.bz"

        final = []
        if footprint and len(footprint.keys()) == 4:  # a rectangle (or bbox)
            lower_left = {
                "longitude": footprint["lonmin"],
                "latitude": footprint["latmin"],
            }
            upper_right = {
                "longitude": footprint["lonmax"],
                "latitude": footprint["latmax"],
            }
        else:
            lower_left, upper_right = None, None
        try:
            results = api.search(
                usgs_dataset,
                usgs_catalog_node,
                start_date=start_date,
                end_date=end_date,
                ll=lower_left,
                ur=upper_right,
            )

            for result in results["data"]["results"]:
                r_lower_left = result["spatialFootprint"]["coordinates"][0][0]
                r_upper_right = result["spatialFootprint"]["coordinates"][0][2]
                summary_match = result_summary_pattern.match(
                    result["summary"]).groupdict()
                result["geometry"] = geometry.box(r_lower_left[0],
                                                  r_lower_left[1],
                                                  r_upper_right[0],
                                                  r_upper_right[1])

                # Same method as in base.py, Search.__init__()
                # Prepare the metadata mapping
                # Do a shallow copy, the structure is flat enough for this to be sufficient
                metas = DEFAULT_METADATA_MAPPING.copy()
                # Update the defaults with the mapping value. This will add any new key
                # added by the provider mapping that is not in the default metadata.
                # A deepcopy is done to prevent self.config.metadata_mapping from being modified when metas[metadata]
                # is a list and is modified
                metas.update(copy.deepcopy(self.config.metadata_mapping))
                metas = mtd_cfg_as_jsonpath(metas)

                result["productType"] = usgs_dataset

                product_properties = properties_from_json(result, metas)

                if getattr(self.config, "product_location_scheme",
                           "https") == "file":
                    product_properties["downloadLink"] = dl_url_pattern.format(
                        base_url="file://")
                else:
                    product_properties["downloadLink"] = dl_url_pattern.format(
                        base_url=self.config.google_base_url.rstrip("/"),
                        entity=result["entityId"],
                        **summary_match)

                final.append(
                    EOProduct(
                        productType=product_type,
                        provider=self.provider,
                        properties=product_properties,
                        geometry=footprint,
                    ))
        except USGSError as e:
            logger.debug(
                "Product type %s does not exist on catalogue %s",
                usgs_dataset,
                usgs_catalog_node,
            )
            logger.debug("Skipping error: %s", e)
        api.logout()
        return final, len(final)
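
As the comment in Example #11 explains, {path:0>3} zero-pads to width 3.
A quick demonstration (base_url here is a placeholder):

dl_url_pattern = "{base_url}/L8/{path:0>3}/{row:0>3}/{entity}.tar.bz"
url = dl_url_pattern.format(
    base_url="https://example.com", path="7", row="42", entity="SCENE"
)
assert url == "https://example.com/L8/007/042/SCENE.tar.bz"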
Example #12
    def download(self, product, auth=None, progress_callback=None, **kwargs):
        """Download method for AWS S3 API.

        :param product: The EO product to download
        :type product: :class:`~eodag.api.product.EOProduct`
        :param auth: (optional) The configuration of a plugin of type Authentication
        :type auth: :class:`~eodag.config.PluginConfig`
        :param progress_callback: (optional) A method or a callable object
                                  which takes a current size and a maximum
                                  size as inputs and handle progress bar
                                  creation and update to give the user a
                                  feedback on the download progress
        :type progress_callback: :class:`~eodag.utils.ProgressCallback` or None
        :return: The absolute path to the downloaded product in the local filesystem
        :rtype: str
        """
        product_conf = getattr(self.config, "products",
                               {}).get(product.product_type, {})

        build_safe = product_conf.get("build_safe", False)

        # product conf overrides provider conf for "flatten_top_dirs"
        flatten_top_dirs = product_conf.get(
            "flatten_top_dirs", getattr(self.config, "flatten_top_dirs",
                                        False))

        # extra metadata needed for the SAFE product
        if build_safe and "fetch_metadata" in product_conf.keys():
            fetch_format = product_conf["fetch_metadata"]["fetch_format"]
            update_metadata = product_conf["fetch_metadata"]["update_metadata"]
            fetch_url = product_conf["fetch_metadata"]["fetch_url"].format(
                **product.properties)
            if fetch_format == "json":
                logger.info("Fetching extra metadata from %s" % fetch_url)
                resp = requests.get(fetch_url)
                json_resp = resp.json()
                update_metadata = mtd_cfg_as_jsonpath(update_metadata)
                update_metadata = properties_from_json(json_resp,
                                                       update_metadata)
                product.properties.update(update_metadata)
            else:
                logger.warning(
                    "SAFE metadata fetch format %s not implemented" %
                    fetch_format)
        # if assets are defined, use them instead of scanning product.location
        if hasattr(product, "assets"):
            bucket_names_and_prefixes = []
            for complementary_url in getattr(product, "assets", {}).values():
                bucket_names_and_prefixes.append(
                    self.get_bucket_name_and_prefix(
                        product, complementary_url.get("href", "")))
        else:
            bucket_names_and_prefixes = [
                self.get_bucket_name_and_prefix(product)
            ]

        # add complementary urls
        for complementary_url_key in product_conf.get("complementary_url_key",
                                                      []):
            bucket_names_and_prefixes.append(
                self.get_bucket_name_and_prefix(
                    product, product.properties[complementary_url_key]))

        # prepare download & create dirs
        product_local_path, record_filename = self._prepare_download(
            product, **kwargs)
        if not product_local_path or not record_filename:
            return product_local_path
        product_local_path = product_local_path.replace(".zip", "")
        # remove existing incomplete file
        if os.path.isfile(product_local_path):
            os.remove(product_local_path)
        # create product dest dir
        if not os.path.isdir(product_local_path):
            os.makedirs(product_local_path)

        # progress bar init
        if progress_callback is None:
            progress_callback = get_progress_callback()
        progress_callback.desc = product.properties.get("id", "")
        progress_callback.position = 1

        # authenticate & get product size
        authenticated_objects = {}
        total_size = 0
        auth_error_messages = set()
        for idx, pack in enumerate(bucket_names_and_prefixes):
            try:
                bucket_name, prefix = pack
                if bucket_name not in authenticated_objects:
                    # get the prefixes' longest common base path
                    common_prefix = ""
                    prefix_split = prefix.split("/")
                    prefixes_in_bucket = len([
                        p for b, p in bucket_names_and_prefixes
                        if b == bucket_name
                    ])
                    for i in range(1, len(prefix_split)):
                        common_prefix = "/".join(prefix_split[0:i])
                        if (len([
                                p for b, p in bucket_names_and_prefixes
                                if b == bucket_name and common_prefix in p
                        ]) < prefixes_in_bucket):
                            common_prefix = "/".join(prefix_split[0:i - 1])
                            break
                    # connect to AWS S3 and get the bucket's authenticated objects
                    s3_objects = self.get_authenticated_objects(
                        bucket_name, common_prefix, auth)
                    authenticated_objects[bucket_name] = s3_objects
                else:
                    s3_objects = authenticated_objects[bucket_name]

                total_size += sum(
                    [p.size for p in s3_objects.filter(Prefix=prefix)])

            except AuthenticationError as e:
                logger.warning("Unexpected error: %s" % e)
                logger.warning("Skipping %s/%s" % (bucket_name, prefix))
                auth_error_messages.add(str(e))
            except ClientError as e:
                err = e.response["Error"]
                auth_messages = [
                    "AccessDenied",
                    "InvalidAccessKeyId",
                    "SignatureDoesNotMatch",
                ]
                if err["Code"] in auth_messages and "key" in err[
                        "Message"].lower():
                    raise AuthenticationError(
                        "HTTP error {} returned\n{}: {}\nPlease check your credentials for {}"
                        .format(
                            e.response["ResponseMetadata"]["HTTPStatusCode"],
                            err["Code"],
                            err["Message"],
                            self.provider,
                        ))
                logger.warning("Unexpected error: %s" % e)
                logger.warning("Skipping %s/%s" % (bucket_name, prefix))
                auth_error_messages.add(str(e))

        # could not auth on any bucket
        if not authenticated_objects:
            raise AuthenticationError(", ".join(auth_error_messages))

        # bucket_names_and_prefixes with unauthenticated items filtered out
        auth_bucket_names_and_prefixes = [
            p for p in bucket_names_and_prefixes
            if p[0] in authenticated_objects.keys()
        ]

        # download
        progress_callback.max_size = total_size
        progress_callback.reset()
        for bucket_name, prefix in auth_bucket_names_and_prefixes:
            try:
                s3_objects = authenticated_objects[bucket_name]

                for product_chunk in s3_objects.filter(Prefix=prefix):
                    chunck_rel_path = self.get_chunck_dest_path(
                        product,
                        product_chunk,
                        build_safe=build_safe,
                        dir_prefix=prefix,
                    )
                    chunck_abs_path = os.path.join(product_local_path,
                                                   chunck_rel_path)
                    chunck_abs_path_dir = os.path.dirname(chunck_abs_path)
                    if not os.path.isdir(chunck_abs_path_dir):
                        os.makedirs(chunck_abs_path_dir)

                    if not os.path.isfile(chunck_abs_path):
                        product_chunk.Bucket().download_file(
                            product_chunk.key,
                            chunck_abs_path,
                            ExtraArgs=getattr(s3_objects, "_params", {}),
                            Callback=progress_callback,
                        )

            except AuthenticationError as e:
                logger.warning("Unexpected error: %s" % e)
                logger.warning("Skipping %s/%s" % (bucket_name, prefix))
            except ClientError as e:
                err = e.response["Error"]
                auth_messages = [
                    "AccessDenied",
                    "InvalidAccessKeyId",
                    "SignatureDoesNotMatch",
                ]
                if err["Code"] in auth_messages and "key" in err[
                        "Message"].lower():
                    raise AuthenticationError(
                        "HTTP error {} returned\n{}: {}\nPlease check your credentials for {}"
                        .format(
                            e.response["ResponseMetadata"]["HTTPStatusCode"],
                            err["Code"],
                            err["Message"],
                            self.provider,
                        ))
                logger.warning("Unexpected error: %s" % e)
                logger.warning("Skipping %s/%s" % (bucket_name, prefix))

        # finalize safe product
        if build_safe and "S2_MSI" in product.product_type:
            self.finalize_s2_safe_product(product_local_path)
        # flatten directory structure
        elif flatten_top_dirs:
            tmp_product_local_path = "%s-tmp" % product_local_path
            for d, dirs, files in os.walk(product_local_path):
                if len(files) != 0:
                    shutil.copytree(d, tmp_product_local_path)
                    shutil.rmtree(product_local_path)
                    os.rename(tmp_product_local_path, product_local_path)
                    break

        # save hash/record file
        with open(record_filename, "w") as fh:
            fh.write(product.remote_location)
        logger.debug("Download recorded in %s", record_filename)

        return product_local_path
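
The longest-common-base-path computation buried in Example #12 can be
expressed on its own. A sketch over "/"-separated S3 key prefixes
(illustrative inputs):

def common_base_path(prefixes):
    # Longest common "/"-separated base path of a set of key prefixes.
    split = [p.split("/") for p in prefixes]
    common = []
    for parts in zip(*split):
        if len(set(parts)) != 1:
            break
        common.append(parts[0])
    return "/".join(common)

assert common_base_path(["a/b/c/x", "a/b/d/y"]) == "a/b"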
Example #13
    def download(self, product, auth=None, progress_callback=None, **kwargs):
        """Download method for AWS S3 API.

        :param product: The EO product to download
        :type product: :class:`~eodag.api.product.EOProduct`
        :param auth: (optional) The configuration of a plugin of type Authentication
        :type auth: :class:`~eodag.config.PluginConfig`
        :param progress_callback: (optional) A method or a callable object
                                  which takes a current size and a maximum
                                  size as inputs and handle progress bar
                                  creation and update to give the user a
                                  feedback on the download progress
        :type progress_callback: :class:`~eodag.utils.ProgressCallback` or None
        :return: The absolute path to the downloaded product in the local filesystem
        :rtype: str
        """
        product_conf = getattr(self.config, "products",
                               {}).get(product.product_type, {})

        build_safe = product_conf.get("build_safe", False)

        # product conf overrides provider conf for "flatten_top_dirs"
        flatten_top_dirs = product_conf.get(
            "flatten_top_dirs", getattr(self.config, "flatten_top_dirs",
                                        False))

        # extra metadata needed for the SAFE product
        if build_safe and "fetch_metadata" in product_conf.keys():
            fetch_format = product_conf["fetch_metadata"]["fetch_format"]
            update_metadata = product_conf["fetch_metadata"]["update_metadata"]
            fetch_url = product_conf["fetch_metadata"]["fetch_url"].format(
                **product.properties)
            if fetch_format == "json":
                logger.info("Fetching extra metadata from %s" % fetch_url)
                resp = requests.get(fetch_url)
                json_resp = resp.json()
                update_metadata = mtd_cfg_as_jsonpath(update_metadata)
                update_metadata = properties_from_json(json_resp,
                                                       update_metadata)
                product.properties.update(update_metadata)
            else:
                logger.warning(
                    "SAFE metadata fetch format %s not implemented" %
                    fetch_format)
        # if assets are defined, use them instead of scanning product.location
        if hasattr(product, "assets"):
            bucket_names_and_prefixes = []
            for complementary_url in getattr(product, "assets", {}).values():
                bucket_names_and_prefixes.append(
                    self.get_bucket_name_and_prefix(
                        product, complementary_url.get("href", "")))
        else:
            bucket_names_and_prefixes = [
                self.get_bucket_name_and_prefix(product)
            ]

        # add complementary urls
        for complementary_url_key in product_conf.get("complementary_url_key",
                                                      []):
            bucket_names_and_prefixes.append(
                self.get_bucket_name_and_prefix(
                    product, product.properties[complementary_url_key]))

        # prepare download & create dirs
        product_local_path, record_filename = self._prepare_download(
            product, **kwargs)
        if not product_local_path or not record_filename:
            return product_local_path
        product_local_path = product_local_path.replace(".zip", "")
        # remove existing incomplete file
        if os.path.isfile(product_local_path):
            os.remove(product_local_path)
        # create product dest dir
        if not os.path.isdir(product_local_path):
            os.makedirs(product_local_path)

        with tqdm(
                total=len(bucket_names_and_prefixes),
                unit="parts",
                desc="Downloading product parts",
        ) as bar:

            for bucket_name, prefix in bucket_names_and_prefixes:
                try:
                    # connect to aws s3
                    access_key, access_secret = auth
                    s3 = boto3.resource(
                        "s3",
                        aws_access_key_id=access_key,
                        aws_secret_access_key=access_secret,
                    )
                    bucket = s3.Bucket(bucket_name)

                    total_size = sum([
                        p.size for p in bucket.objects.filter(
                            Prefix=prefix, RequestPayer="requester")
                    ])
                    progress_callback.max_size = total_size
                    for product_chunk in bucket.objects.filter(
                            Prefix=prefix, RequestPayer="requester"):
                        chunck_rel_path = self.get_chunck_dest_path(
                            product,
                            product_chunk,
                            build_safe=build_safe,
                            dir_prefix=prefix,
                        )
                        chunck_abs_path = os.path.join(product_local_path,
                                                       chunck_rel_path)
                        chunck_abs_path_dir = os.path.dirname(chunck_abs_path)
                        if not os.path.isdir(chunck_abs_path_dir):
                            os.makedirs(chunck_abs_path_dir)

                        if not os.path.isfile(chunck_abs_path):
                            bucket.download_file(
                                product_chunk.key,
                                chunck_abs_path,
                                ExtraArgs={"RequestPayer": "requester"},
                                Callback=progress_callback,
                            )
                except ClientError as e:
                    err = e.response["Error"]
                    auth_messages = [
                        "InvalidAccessKeyId", "SignatureDoesNotMatch"
                    ]
                    if err["Code"] in auth_messages and "key" in err[
                            "Message"].lower():
                        raise AuthenticationError(
                            "HTTP error {} returned\n{}: {}\nPlease check your credentials for {}"
                            .format(
                                e.response["ResponseMetadata"]
                                ["HTTPStatusCode"],
                                err["Code"],
                                err["Message"],
                                self.provider,
                            ))
                    logger.warning("Unexpected error: %s" % e)
                    logger.warning("Skipping %s/%s" % (bucket_name, prefix))
                bar.update(1)

        # finalize safe product
        if build_safe and "S2_MSI" in product.product_type:
            self.finalize_s2_safe_product(product_local_path)
        # flatten directory structure
        elif flatten_top_dirs:
            tmp_product_local_path = "%s-tmp" % product_local_path
            for d, dirs, files in os.walk(product_local_path):
                if len(files) != 0:
                    shutil.copytree(d, tmp_product_local_path)
                    shutil.rmtree(product_local_path)
                    os.rename(tmp_product_local_path, product_local_path)
                    break

        # save hash/record file
        with open(record_filename, "w") as fh:
            fh.write(product.remote_location)
        logger.debug("Download recorded in %s", record_filename)

        return product_local_path
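
The flatten step shared by Examples #7, #12 and #13 promotes the first
directory level that actually contains files to the product root. Isolated
as a sketch (name is illustrative):

import os
import shutil

def flatten_top_dirs(product_dir):
    # Walk down to the first level holding files, copy it aside, then
    # swap it in as the new product root.
    tmp_dir = "%s-tmp" % product_dir
    for d, dirs, files in os.walk(product_dir):
        if files:
            shutil.copytree(d, tmp_dir)
            shutil.rmtree(product_dir)
            os.rename(tmp_dir, product_dir)
            break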