Example #1
def test_import_missing_dependency():
    with pytest.raises(
        ImportError,
        match=".*%s.*PURPOSE.*pip install 'bionic\\[%s\\]'.*"
        % (TEST_PACKAGE_NAME, TEST_EXTRA_NAME),
    ):
        import_optional_dependency(TEST_PACKAGE_NAME, purpose="PURPOSE")
Example #2
def loky_executor():
    loky = import_optional_dependency("loky", purpose="parallel execution")
    return loky.get_reusable_executor(
        max_workers=None,
        initializer=logging_initializer,
        initargs=(get_singleton_manager().logging_queue,),
    )
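The executor returned here is loky's reusable process pool, which follows the standard concurrent.futures interface. A hypothetical caller (not part of this project) could use it roughly as follows, assuming logging_initializer and the singleton manager are importable in the worker processes:

executor = loky_executor()
# submit() schedules the call in a worker process and returns a Future.
future = executor.submit(pow, 2, 10)
assert future.result() == 1024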
Example #3
    def _stage(self):
        cloudpickle = import_optional_dependency("cloudpickle")

        path = self.inputs_uri()
        logging.info(f"Staging task {self.name} at {path}")

        with self.gcs_fs.open(path, "wb") as f:
            cloudpickle.dump(self, f)
Example #4
def get_aip_client(cache_value=True):
    if cache_value:
        global _cached_aip_client
        if _cached_aip_client is None:
            _cached_aip_client = get_aip_client(cache_value=False)
        return _cached_aip_client

    discovery = import_optional_dependency("googleapiclient.discovery",
                                           raise_on_missing=True)
    logger.info("Initializing AIP client ...")
    return discovery.build("ml", "v1", cache_discovery=False)
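Because the discovery client is memoized in the module-level _cached_aip_client global, repeated calls with cache_value=True reuse the same object. A small illustrative check (assuming googleapiclient is installed):

client_a = get_aip_client()
client_b = get_aip_client()
# Both calls return the single cached discovery client.
assert client_a is client_b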
Example #5
def _run(ipath, gcs_fs):
    cloudpickle = import_optional_dependency("cloudpickle")

    with gcs_fs.open(ipath, "rb") as f:
        task = cloudpickle.load(f)

    # Now that we have the task, set up logging.
    _set_up_logging(task.job_id(), task.config.project_name)
    logging.info(f"Read task from {ipath}")

    result = task.function()

    opath = task.output_uri()
    logging.info(f"Uploading result to {opath}")
    with gcs_fs.open(opath, "wb") as f:
        pickle.dump(result, f)
Example #6
def _set_up_logging(job_id, project_id):
    if os.environ.get("BIONIC_NO_STACKDRIVER", False):
        return

    # TODO This is the ID of the hyperparameter tuning trial currently
    # running on this VM. This field is only set if the current
    # training job is a hyperparameter tuning job. Conductor uses this
    # environment variable, but AIP documentation suggests using
    # TF_CONFIG. Check whether we need to update this env variable.
    # Find more details on TF_CONFIG at this link:
    # https://cloud.google.com/ai-platform/training/docs/distributed-training-details
    trial_id = os.environ.get("CLOUD_ML_TRIAL_ID", None)

    glogging = import_optional_dependency("google.cloud.logging")

    client = glogging.Client(project=project_id)
    resource = glogging.resource.Resource(
        type="ml_job",
        # AIP expects a default task_name for the master cluster. We
        # use a placeholder value until we start using clusters. Once we
        # do, it should be configured based on the cluster.
        labels=dict(job_id=job_id,
                    project_id=project_id,
                    task_name="master-replica-0"),
    )
    labels = None
    if trial_id is not None:
        # Enable grouping by trial when present.
        labels = {"ml.googleapis.com/trial_id": trial_id}

    # Enable only the cloud logger to avoid duplicate messages.
    handler = glogging.handlers.handlers.CloudLoggingHandler(client,
                                                             resource=resource,
                                                             labels=labels)
    root_logger = logging.getLogger()
    # Remove the StreamHandler. Any logs it emits show up as error
    # logs in Stackdriver.
    root_logger.handlers = []
    # We should ideally make this configurable, but until then, let's
    # set the level to DEBUG to write all the logs. It's not hard to
    # filter using log level on Stackdriver so it doesn't create too
    # much noise anyway.
    root_logger.setLevel(logging.DEBUG)
    root_logger.addHandler(handler)
    for logger_name in glogging.handlers.handlers.EXCLUDED_LOGGER_DEFAULTS:
        logging.getLogger(logger_name).propagate = False
Example #7
def test_import_unrecognized_dependency():
    with pytest.raises(AssertionError):
        import_optional_dependency("_UNKNOWN_PACKAGE_", purpose="PURPOSE")
Example #8
def test_import_missing_dependency_without_raising():
    module = import_optional_dependency(TEST_PACKAGE_NAME, raise_on_missing=False)
    assert module is None
Example #9
def get_gcp_project_id():
    google_auth = import_optional_dependency(
        "google.auth", purpose="Get GCP project id from the environment")
    _, project = google_auth.default()
    return project
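All of the snippets above go through the same helper, import_optional_dependency. Below is a minimal sketch of how such a helper could be written, inferred only from the behavior exercised in these examples (a registry of known packages mapped to pip extras, a purpose string in the error message, and a raise_on_missing flag). The registry contents and function body here are illustrative assumptions, not bionic's actual implementation:

import importlib

# Hypothetical registry mapping importable module names to the pip
# "extra" that provides them; entries are illustrative only.
_KNOWN_PACKAGES = {
    "loky": "parallel",
    "cloudpickle": "gcp",
    "googleapiclient.discovery": "gcp",
    "google.cloud.logging": "gcp",
    "google.auth": "gcp",
}


def import_optional_dependency(name, purpose=None, raise_on_missing=True):
    # Asking for a package we don't know about is a programming error,
    # which matches the AssertionError expected in Example #7.
    assert name in _KNOWN_PACKAGES, f"Unrecognized optional dependency {name!r}"

    try:
        return importlib.import_module(name)
    except ImportError:
        # Example #8 expects None instead of an exception in this mode.
        if not raise_on_missing:
            return None
        extra = _KNOWN_PACKAGES[name]
        reason = f" for {purpose}" if purpose else ""
        # Example #1 expects the package name, the purpose, and a
        # "pip install 'bionic[<extra>]'" hint in the message.
        raise ImportError(
            f"Unable to import {name}, which is needed{reason}; "
            f"try running: pip install 'bionic[{extra}]'"
        )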