コード例 #1
0
ファイル: serve.py プロジェクト: zongzhenh/cortex
def get_spec(provider, storage, cache_dir, spec_path):
    """Load the msgpack-encoded API spec and return its decoded contents.

    For the "local" provider, ``spec_path`` is read directly. For any other
    provider, ``spec_path`` is treated as an S3 path: its key is fetched
    through ``storage.download_file`` into ``<cache_dir>/api_spec.msgpack``
    and that local copy is read instead.

    Args:
        provider: Provider name; only the value "local" is special-cased.
        storage: Object exposing ``download_file(key, local_path)``.
        cache_dir: Directory that receives the downloaded spec file.
        spec_path: Local file path ("local" provider) or S3 path (otherwise).

    Returns:
        Whatever ``util.read_msgpack`` returns for the spec file.
    """
    if provider != "local":
        cached_spec = os.path.join(cache_dir, "api_spec.msgpack")
        # Only the key is needed; the bucket is implied by `storage`.
        _, object_key = S3.deconstruct_s3_path(spec_path)
        storage.download_file(object_key, cached_spec)
        spec_path = cached_spec

    return util.read_msgpack(spec_path)
コード例 #2
0
    def get_obj(self, key):
        """Return the msgpack object stored under ``key``, memoized per instance.

        On a miss the object is downloaded to ``<cache_dir>/<key>`` via
        ``self.download_file``, decoded with ``util.read_msgpack`` and kept in
        ``self._obj_cache`` so later calls with the same key skip the download.
        """
        memo = self._obj_cache
        if key not in memo:
            local_copy = os.path.join(self.cache_dir, key)
            self.download_file(key, local_copy)
            memo[key] = util.read_msgpack(local_copy)
        return memo[key]
コード例 #3
0
ファイル: context.py プロジェクト: gulahmed/cortex
    def __init__(self, **kwargs):
        """Load a context object and set up storage, metrics and ID maps.

        Keyword Args:
            cache_dir: Directory used for locally cached files. If omitted,
                it defaults to a ``cache`` directory next to ``local_path``.
            local_path: Path of a msgpack-encoded context file on disk.
            obj: Already-loaded context mapping, used as-is.
            raw_obj: Context mapping, used as-is.
            s3_path: S3 path of a msgpack-encoded context file; it is
                downloaded to ``<cache_dir>/context.msgpack`` and then read.
            workload_id: Optional; stored on the instance (``None`` if absent).
            local_storage_path: If present, ``LocalStorage`` rooted at this
                path is used instead of S3.

        The context source is picked in this order: ``local_path``, ``obj``,
        ``raw_obj``, ``s3_path``.

        Raises:
            ValueError: If no cache directory can be determined, if none of
                the context sources is supplied, or if the context's API
                version differs from ``consts.CORTEX_VERSION``.
            KeyError: If the context lacks one of the required keys or the
                ``HOST_IP`` environment variable is not set.
        """
        if "cache_dir" in kwargs:
            self.cache_dir = kwargs["cache_dir"]
        elif "local_path" in kwargs:
            local_path_dir = os.path.dirname(os.path.abspath(kwargs["local_path"]))
            self.cache_dir = os.path.join(local_path_dir, "cache")
        else:
            raise ValueError("cache_dir must be specified (or inferred from local_path)")
        util.mkdir_p(self.cache_dir)

        if "local_path" in kwargs:
            self.ctx = util.read_msgpack(kwargs["local_path"])
        elif "obj" in kwargs:
            self.ctx = kwargs["obj"]
        elif "raw_obj" in kwargs:
            self.ctx = kwargs["raw_obj"]
        elif "s3_path" in kwargs:
            # Membership test (not the bare string literal, which is always
            # truthy) so a missing source reaches the ValueError below.
            local_ctx_path = os.path.join(self.cache_dir, "context.msgpack")
            bucket, key = S3.deconstruct_s3_path(kwargs["s3_path"])
            S3(bucket, client_config={}).download_file(key, local_ctx_path)
            self.ctx = util.read_msgpack(local_ctx_path)
        else:
            # kwargs is a dict; format it rather than concatenating to a str.
            raise ValueError("invalid context args: {}".format(kwargs))

        self.workload_id = kwargs.get("workload_id")

        self.id = self.ctx["id"]
        self.key = self.ctx["key"]
        self.metadata_root = self.ctx["metadata_root"]
        self.cortex_config = self.ctx["cortex_config"]
        self.deployment_version = self.ctx["deployment_version"]
        self.root = self.ctx["root"]
        self.status_prefix = self.ctx["status_prefix"]
        self.app = self.ctx["app"]
        # A falsy "apis" entry (e.g. None) is normalized to an empty dict.
        self.apis = self.ctx["apis"] or {}
        self.api_version = self.cortex_config["api_version"]
        self.monitoring = None
        self.project_id = self.ctx["project_id"]
        self.project_key = self.ctx["project_key"]

        if "local_storage_path" in kwargs:
            self.storage = LocalStorage(base_dir=kwargs["local_storage_path"])
        else:
            self.storage = S3(
                bucket=self.cortex_config["bucket"],
                region=self.cortex_config["region"],
                client_config={},
            )

        # statsd client pointed at the host given by the HOST_IP env var.
        host_ip = os.environ["HOST_IP"]
        datadog.initialize(statsd_host=host_ip, statsd_port="8125")
        self.statsd = datadog.statsd

        if self.api_version != consts.CORTEX_VERSION:
            raise ValueError(
                "API version mismatch (Context: {}, Image: {})".format(
                    self.api_version, consts.CORTEX_VERSION
                )
            )

        # This affects Tensorflow S3 access
        os.environ["AWS_REGION"] = self.cortex_config.get("region", "")

        # ID maps
        self.apis_id_map = ResourceMap(self.apis) if self.apis else None
        self.id_map = self.apis_id_map
コード例 #4
0
    def __init__(self, **kwargs):
        """Load a context object and derive resource collections and ID maps.

        Keyword Args:
            cache_dir: Directory used for locally cached files. If omitted,
                it defaults to a ``cache`` directory next to ``local_path``.
            local_path: Path of a msgpack-encoded raw context file on disk;
                passed through ``_deserialize_raw_ctx`` after reading.
            obj: Already-deserialized context mapping, used as-is.
            raw_obj: Raw context mapping; passed through
                ``_deserialize_raw_ctx``.
            s3_path: S3 path of a msgpack-encoded raw context file; it is
                downloaded to ``<cache_dir>/context.msgpack``, read and
                passed through ``_deserialize_raw_ctx``.
            workload_id: Optional; stored on the instance (``None`` if absent).
            local_storage_path: If present, ``LocalStorage`` rooted at this
                path is used instead of S3.

        The context source is picked in this order: ``local_path``, ``obj``,
        ``raw_obj``, ``s3_path``.

        Raises:
            ValueError: If no cache directory can be determined, if none of
                the context sources is supplied, or if the context's API
                version differs from ``consts.CORTEX_VERSION``.
            KeyError: If the context lacks one of the required keys.
        """
        if "cache_dir" in kwargs:
            self.cache_dir = kwargs["cache_dir"]
        elif "local_path" in kwargs:
            local_path_dir = os.path.dirname(os.path.abspath(kwargs["local_path"]))
            self.cache_dir = os.path.join(local_path_dir, "cache")
        else:
            raise ValueError("cache_dir must be specified (or inferred from local_path)")
        util.mkdir_p(self.cache_dir)

        if "local_path" in kwargs:
            ctx_raw = util.read_msgpack(kwargs["local_path"])
            self.ctx = _deserialize_raw_ctx(ctx_raw)
        elif "obj" in kwargs:
            self.ctx = kwargs["obj"]
        elif "raw_obj" in kwargs:
            ctx_raw = kwargs["raw_obj"]
            self.ctx = _deserialize_raw_ctx(ctx_raw)
        elif "s3_path" in kwargs:
            # Membership test (not the bare string literal, which is always
            # truthy) so a missing source reaches the ValueError below.
            local_ctx_path = os.path.join(self.cache_dir, "context.msgpack")
            bucket, key = S3.deconstruct_s3_path(kwargs["s3_path"])
            S3(bucket, client_config={}).download_file(key, local_ctx_path)
            ctx_raw = util.read_msgpack(local_ctx_path)
            self.ctx = _deserialize_raw_ctx(ctx_raw)
        else:
            # kwargs is a dict; format it rather than concatenating to a str.
            raise ValueError("invalid context args: {}".format(kwargs))

        self.workload_id = kwargs.get("workload_id")

        self.id = self.ctx["id"]
        self.key = self.ctx["key"]
        self.cortex_config = self.ctx["cortex_config"]
        self.dataset_version = self.ctx["dataset_version"]
        self.root = self.ctx["root"]
        self.raw_dataset = self.ctx["raw_dataset"]
        self.status_prefix = self.ctx["status_prefix"]
        self.app = self.ctx["app"]
        self.environment = self.ctx["environment"]
        # Falsy resource collections (e.g. None) are normalized to empty dicts.
        self.python_packages = self.ctx["python_packages"] or {}
        self.raw_columns = self.ctx["raw_columns"] or {}
        self.transformed_columns = self.ctx["transformed_columns"] or {}
        self.transformers = self.ctx["transformers"] or {}
        self.aggregators = self.ctx["aggregators"] or {}
        self.aggregates = self.ctx["aggregates"] or {}
        self.constants = self.ctx["constants"] or {}
        self.models = self.ctx["models"] or {}
        self.estimators = self.ctx["estimators"] or {}
        self.apis = self.ctx["apis"] or {}
        # One training dataset per model, keyed by the model's key.
        self.training_datasets = {k: v["dataset"] for k, v in self.models.items()}
        self.api_version = self.cortex_config["api_version"]

        if "local_storage_path" in kwargs:
            self.storage = LocalStorage(base_dir=kwargs["local_storage_path"])
        else:
            self.storage = S3(
                bucket=self.cortex_config["bucket"],
                region=self.cortex_config["region"],
                client_config={},
            )

        if self.api_version != consts.CORTEX_VERSION:
            raise ValueError(
                "API version mismatch (Context: {}, Image: {})".format(
                    self.api_version, consts.CORTEX_VERSION
                )
            )

        # Raw and transformed columns combined into a single mapping.
        self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns)

        self.raw_column_names = list(self.raw_columns.keys())
        self.transformed_column_names = list(self.transformed_columns.keys())
        self.column_names = list(self.columns.keys())

        # Internal caches
        self._transformer_impls = {}
        self._aggregator_impls = {}
        self._estimator_impls = {}
        self._metadatas = {}
        self._obj_cache = {}
        self.spark_uploaded_impls = {}

        # This affects Tensorflow S3 access
        os.environ["AWS_REGION"] = self.cortex_config.get("region", "")

        # Id map: each per-resource map is None when its collection is empty.
        self.pp_id_map = ResourceMap(self.python_packages) if self.python_packages else None
        self.rf_id_map = ResourceMap(self.raw_columns) if self.raw_columns else None
        self.ag_id_map = ResourceMap(self.aggregates) if self.aggregates else None
        self.tf_id_map = ResourceMap(self.transformed_columns) if self.transformed_columns else None
        self.td_id_map = ResourceMap(self.training_datasets) if self.training_datasets else None
        self.models_id_map = ResourceMap(self.models) if self.models else None
        self.apis_id_map = ResourceMap(self.apis) if self.apis else None
        self.constants_id_map = ResourceMap(self.constants) if self.constants else None
        # NOTE(review): some of these arguments may be None; presumably
        # util.merge_dicts_overwrite tolerates that - confirm.
        self.id_map = util.merge_dicts_overwrite(
            self.pp_id_map,
            self.rf_id_map,
            self.ag_id_map,
            self.tf_id_map,
            self.td_id_map,
            self.models_id_map,
            self.apis_id_map,
            self.constants_id_map,
        )
コード例 #5
0
ファイル: serve.py プロジェクト: lezoudali/cortex
def get_spec(cache_dir, s3_path):
    """Download the API spec from S3 and return its decoded msgpack contents.

    Args:
        cache_dir: Directory that receives the file as ``api_spec.msgpack``.
        s3_path: S3 path of the msgpack-encoded spec.

    Returns:
        Whatever ``util.read_msgpack`` returns for the downloaded file.
    """
    destination = os.path.join(cache_dir, "api_spec.msgpack")
    bucket_name, object_key = S3.deconstruct_s3_path(s3_path)
    client = S3(bucket_name, client_config={})
    client.download_file(object_key, destination)
    return util.read_msgpack(destination)