def get_spec(provider, storage, cache_dir, spec_path):
    """Load the API spec msgpack, reading locally or downloading via storage.

    For the "local" provider, spec_path is a path on disk and is read
    directly. Otherwise spec_path is an s3 path: the object key is
    extracted, downloaded into cache_dir, and deserialized from there.
    """
    if provider == "local":
        # Local provider: the spec already lives on disk at spec_path.
        return util.read_msgpack(spec_path)

    # Remote provider: pull the spec object down into the cache first.
    cached_spec = os.path.join(cache_dir, "api_spec.msgpack")
    key = S3.deconstruct_s3_path(spec_path)[1]
    storage.download_file(key, cached_spec)
    return util.read_msgpack(cached_spec)
def get_obj(self, key):
    """Return the deserialized object for key, memoized in self._obj_cache.

    On a cache miss the object is downloaded into self.cache_dir and
    deserialized once; subsequent calls return the cached value.
    """
    cache = self._obj_cache
    if key not in cache:
        # Miss: fetch the raw object into the local cache dir, then parse it.
        local_path = os.path.join(self.cache_dir, key)
        self.download_file(key, local_path)
        cache[key] = util.read_msgpack(local_path)
    return cache[key]
def __init__(self, **kwargs):
    """Build an API deployment context from one of several sources.

    Exactly one context source must be provided in kwargs:
      local_path: path to a msgpack context file on disk
      obj / raw_obj: an already-deserialized context object
      s3_path: an s3:// path to download and deserialize

    Optional kwargs:
      cache_dir: directory for downloaded artifacts (inferred from the
        directory of local_path when omitted alongside local_path)
      workload_id: identifier of the workload using this context
      local_storage_path: when present, use LocalStorage instead of S3

    Raises:
      ValueError: when neither cache_dir nor local_path is given, when no
        context source is given, or when the context's API version does
        not match this image's CORTEX_VERSION.
    """
    # Resolve the cache directory first; downloads below depend on it.
    if "cache_dir" in kwargs:
        self.cache_dir = kwargs["cache_dir"]
    elif "local_path" in kwargs:
        local_path_dir = os.path.dirname(os.path.abspath(kwargs["local_path"]))
        self.cache_dir = os.path.join(local_path_dir, "cache")
    else:
        raise ValueError("cache_dir must be specified (or inferred from local_path)")
    util.mkdir_p(self.cache_dir)

    # Load the context from whichever source was provided.
    if "local_path" in kwargs:
        self.ctx = util.read_msgpack(kwargs["local_path"])
    elif "obj" in kwargs:
        self.ctx = kwargs["obj"]
    elif "raw_obj" in kwargs:
        self.ctx = kwargs["raw_obj"]
    elif "s3_path" in kwargs:
        # BUG FIX: was `elif "s3_path":` — a non-empty string literal is
        # always truthy, making the else branch unreachable and turning a
        # missing s3_path key into a KeyError instead of a ValueError.
        local_ctx_path = os.path.join(self.cache_dir, "context.msgpack")
        bucket, key = S3.deconstruct_s3_path(kwargs["s3_path"])
        S3(bucket, client_config={}).download_file(key, local_ctx_path)
        self.ctx = util.read_msgpack(local_ctx_path)
    else:
        # BUG FIX: `"..." + kwargs` raised TypeError (str + dict) rather
        # than the intended ValueError with a readable message.
        raise ValueError("invalid context args: {}".format(kwargs))

    self.workload_id = kwargs.get("workload_id")

    self.id = self.ctx["id"]
    self.key = self.ctx["key"]
    self.metadata_root = self.ctx["metadata_root"]
    self.cortex_config = self.ctx["cortex_config"]
    self.deployment_version = self.ctx["deployment_version"]
    self.root = self.ctx["root"]
    self.status_prefix = self.ctx["status_prefix"]
    self.app = self.ctx["app"]
    self.apis = self.ctx["apis"] or {}
    self.api_version = self.cortex_config["api_version"]
    self.monitoring = None
    self.project_id = self.ctx["project_id"]
    self.project_key = self.ctx["project_key"]

    if "local_storage_path" in kwargs:
        self.storage = LocalStorage(base_dir=kwargs["local_storage_path"])
    else:
        self.storage = S3(
            bucket=self.cortex_config["bucket"],
            region=self.cortex_config["region"],
            client_config={},
        )

    host_ip = os.environ["HOST_IP"]
    # NOTE(review): statsd_port is passed as the string "8125"; an int is
    # more conventional for datadog.initialize — confirm before changing.
    datadog.initialize(statsd_host=host_ip, statsd_port="8125")
    self.statsd = datadog.statsd

    if self.api_version != consts.CORTEX_VERSION:
        raise ValueError(
            "API version mismatch (Context: {}, Image: {})".format(
                self.api_version, consts.CORTEX_VERSION
            )
        )

    # This affects Tensorflow S3 access
    os.environ["AWS_REGION"] = self.cortex_config.get("region", "")

    # ID maps
    self.apis_id_map = ResourceMap(self.apis) if self.apis else None
    self.id_map = self.apis_id_map
def __init__(self, **kwargs):
    """Build a full pipeline context from one of several sources.

    Exactly one context source must be provided in kwargs:
      local_path: path to a msgpack context file on disk (deserialized
        via _deserialize_raw_ctx)
      obj: an already-deserialized context object (used as-is)
      raw_obj: a raw context object (deserialized via _deserialize_raw_ctx)
      s3_path: an s3:// path to download and deserialize

    Optional kwargs:
      cache_dir: directory for downloaded artifacts (inferred from the
        directory of local_path when omitted alongside local_path)
      workload_id: identifier of the workload using this context
      local_storage_path: when present, use LocalStorage instead of S3

    Raises:
      ValueError: when neither cache_dir nor local_path is given, when no
        context source is given, or when the context's API version does
        not match this image's CORTEX_VERSION.
    """
    # Resolve the cache directory first; downloads below depend on it.
    if "cache_dir" in kwargs:
        self.cache_dir = kwargs["cache_dir"]
    elif "local_path" in kwargs:
        local_path_dir = os.path.dirname(os.path.abspath(kwargs["local_path"]))
        self.cache_dir = os.path.join(local_path_dir, "cache")
    else:
        raise ValueError("cache_dir must be specified (or inferred from local_path)")
    util.mkdir_p(self.cache_dir)

    # Load the context from whichever source was provided.
    if "local_path" in kwargs:
        ctx_raw = util.read_msgpack(kwargs["local_path"])
        self.ctx = _deserialize_raw_ctx(ctx_raw)
    elif "obj" in kwargs:
        self.ctx = kwargs["obj"]
    elif "raw_obj" in kwargs:
        ctx_raw = kwargs["raw_obj"]
        self.ctx = _deserialize_raw_ctx(ctx_raw)
    elif "s3_path" in kwargs:
        # BUG FIX: was `elif "s3_path":` — a non-empty string literal is
        # always truthy, making the else branch unreachable and turning a
        # missing s3_path key into a KeyError instead of a ValueError.
        local_ctx_path = os.path.join(self.cache_dir, "context.msgpack")
        bucket, key = S3.deconstruct_s3_path(kwargs["s3_path"])
        S3(bucket, client_config={}).download_file(key, local_ctx_path)
        ctx_raw = util.read_msgpack(local_ctx_path)
        self.ctx = _deserialize_raw_ctx(ctx_raw)
    else:
        # BUG FIX: `"..." + kwargs` raised TypeError (str + dict) rather
        # than the intended ValueError with a readable message.
        raise ValueError("invalid context args: {}".format(kwargs))

    self.workload_id = kwargs.get("workload_id")

    self.id = self.ctx["id"]
    self.key = self.ctx["key"]
    self.cortex_config = self.ctx["cortex_config"]
    self.dataset_version = self.ctx["dataset_version"]
    self.root = self.ctx["root"]
    self.raw_dataset = self.ctx["raw_dataset"]
    self.status_prefix = self.ctx["status_prefix"]
    self.app = self.ctx["app"]
    self.environment = self.ctx["environment"]
    self.python_packages = self.ctx["python_packages"] or {}
    self.raw_columns = self.ctx["raw_columns"] or {}
    self.transformed_columns = self.ctx["transformed_columns"] or {}
    self.transformers = self.ctx["transformers"] or {}
    self.aggregators = self.ctx["aggregators"] or {}
    self.aggregates = self.ctx["aggregates"] or {}
    self.constants = self.ctx["constants"] or {}
    self.models = self.ctx["models"] or {}
    self.estimators = self.ctx["estimators"] or {}
    self.apis = self.ctx["apis"] or {}
    self.training_datasets = {k: v["dataset"] for k, v in self.models.items()}
    self.api_version = self.cortex_config["api_version"]

    if "local_storage_path" in kwargs:
        self.storage = LocalStorage(base_dir=kwargs["local_storage_path"])
    else:
        self.storage = S3(
            bucket=self.cortex_config["bucket"],
            region=self.cortex_config["region"],
            client_config={},
        )

    if self.api_version != consts.CORTEX_VERSION:
        raise ValueError(
            "API version mismatch (Context: {}, Image: {})".format(
                self.api_version, consts.CORTEX_VERSION
            )
        )

    # Combined column view: transformed columns overwrite raw on name clash.
    self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns)

    self.raw_column_names = list(self.raw_columns.keys())
    self.transformed_column_names = list(self.transformed_columns.keys())
    self.column_names = list(self.columns.keys())

    # Internal caches
    self._transformer_impls = {}
    self._aggregator_impls = {}
    self._estimator_impls = {}
    self._metadatas = {}
    self._obj_cache = {}
    self.spark_uploaded_impls = {}

    # This affects Tensorflow S3 access
    os.environ["AWS_REGION"] = self.cortex_config.get("region", "")

    # Id map
    self.pp_id_map = ResourceMap(self.python_packages) if self.python_packages else None
    self.rf_id_map = ResourceMap(self.raw_columns) if self.raw_columns else None
    self.ag_id_map = ResourceMap(self.aggregates) if self.aggregates else None
    self.tf_id_map = ResourceMap(self.transformed_columns) if self.transformed_columns else None
    self.td_id_map = ResourceMap(self.training_datasets) if self.training_datasets else None
    self.models_id_map = ResourceMap(self.models) if self.models else None
    self.apis_id_map = ResourceMap(self.apis) if self.apis else None
    self.constants_id_map = ResourceMap(self.constants) if self.constants else None

    self.id_map = util.merge_dicts_overwrite(
        self.pp_id_map,
        self.rf_id_map,
        self.ag_id_map,
        self.tf_id_map,
        self.td_id_map,
        self.models_id_map,
        self.apis_id_map,
        self.constants_id_map,
    )
def get_spec(cache_dir, s3_path):
    """Download the API spec msgpack from s3_path into cache_dir and parse it."""
    bucket, key = S3.deconstruct_s3_path(s3_path)
    dest = os.path.join(cache_dir, "api_spec.msgpack")
    # Fetch the object into the local cache, then deserialize from disk.
    S3(bucket, client_config={}).download_file(key, dest)
    return util.read_msgpack(dest)