def get_job_name():
    """
    Get the current job name.

    :return: The job name, or None when the environment variable is not set.
    """
    # Read the environment variable once instead of checking and then
    # re-reading it (the original performed two lookups).
    # Returning the value directly preserves the implicit-None behavior
    # when the variable is absent.
    return env.get_env(_JOB_NAME_ENV_VAR)
def _tf_task_worker_index(config_json):
    """Build "<task-type>-<task-index>" from a TF cluster config JSON string."""
    task_config = json.loads(config_json).get(_TASK)
    task_type = task_config.get(_JOB_NAME)
    task_index = task_config.get(_TASK_INDEX)
    return task_type + '-' + str(task_index)


def get_worker_index():
    """
    Get the current worker index.

    :return: The worker index, e.g. "worker-0" or "master-0".
    """
    # TensorFlow: TF_CONFIG (or the legacy CLUSTER_SPEC) carries a JSON
    # blob whose "task" section identifies this worker. The two branches
    # were previously duplicated verbatim; they now share one helper and
    # read consistently through env.get_env instead of mixing in
    # os.environ.get.
    tf_config = env.get_env(_TF_CONFIG)
    if tf_config is not None:
        return _tf_task_worker_index(tf_config)
    cluster_spec = env.get_env(_CLUSTER_SPEC)
    if cluster_spec is not None:
        return _tf_task_worker_index(cluster_spec)
    # PyTorch: RANK identifies the process; rank 0 is the master.
    rank = env.get_env(_RANK)
    if rank is not None:
        return "master-0" if rank == "0" else "worker-" + rank
    # Set worker index to "worker-0" when running local training.
    return "worker-0"
def get_job_id():
    """
    Get the current experiment id.

    :return: The experiment id.
    """
    # Distributed training: the launcher exports the YARN application or
    # K8s experiment ID for us.
    job_id = env.get_env(_JOB_ID_ENV_VAR)
    if job_id is None:
        # Local training: mint a random ID and publish it to the process
        # environment so subsequent calls return the same value.
        job_id = uuid.uuid4().hex
        os.environ[_JOB_ID_ENV_VAR] = job_id
    return job_id
def get_db_uri() -> str:
    """
    Get the current DB URI.

    :return: The DB URI.
    """
    global _db_uri
    # Precedence: explicitly-set URI, then environment variable, then default.
    if _db_uri is not None:
        return _db_uri
    env_uri = env.get_env(_DB_URI_ENV_VAR)
    if env_uri is not None:
        return env_uri
    return DEFAULT_SUBMARINE_JDBC_URL
def get_tracking_uri():
    """
    Get the current tracking URI. This may not correspond to the tracking
    URI of the currently active run, since the tracking URI can be updated
    via ``set_tracking_uri``.

    :return: The tracking URI.
    """
    global _tracking_uri
    # Precedence: explicitly-set URI, then environment variable, then default.
    if _tracking_uri is not None:
        return _tracking_uri
    env_uri = env.get_env(_TRACKING_URI_ENV_VAR)
    if env_uri is not None:
        return env_uri
    return DEFAULT_SUBMARINE_JDBC_URL
def is_tracking_uri_set():
    """Returns True if the tracking URI has been set, False otherwise."""
    # Set either programmatically (module global) or via the environment.
    return bool(_tracking_uri or env.get_env(_TRACKING_URI_ENV_VAR))
def test_get_env():
    # Round-trip: a value written into the process environment should be
    # readable back through the helper.
    environ["test"] = "hello"
    value = get_env("test")
    assert value == "hello"
def is_db_uri_set() -> bool:
    """Returns True if the DB URI has been set, False otherwise."""
    # Set either programmatically (module global) or via the environment.
    return bool(_db_uri or env.get_env(_DB_URI_ENV_VAR))
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from comet_ml import Experiment

import submarine
from submarine.utils import env

# experiment = Experiment(api_key="ej6XeyCVjqHM8uLDNj5VGrzjP",
#                         project_name="testing", workspace="pingsutw")

if __name__ == "__main__":
    # Point the tracking store at the cluster-local Submarine database.
    submarine.set_tracking_uri(
        "mysql+pymysql://submarine:password@submarine-database/submarine")

    # Dump the distributed-training environment for debugging.
    print("TF_CONFIG", env.get_env("TF_CONFIG"))
    print("JOB_NAME: ", env.get_env("JOB_NAME"))
    # Bug fix: the queried variable name was misspelled "TPYE", so this
    # line always printed None even when TYPE was set.
    print("TYPE: ", env.get_env("TYPE"))
    print("TASK_INDEX: ", env.get_env("TASK_INDEX"))
    print("CLUSTER_SPEC: ", env.get_env("CLUSTER_SPEC"))
    print("RANK: ", env.get_env("RANK"))

    # Log sample hyper-parameters and a small metric series.
    submarine.log_param("max_iter", 100)
    submarine.log_param("learning_rate", 0.0001)
    submarine.log_param("alpha", 20)
    submarine.log_param("batch_size", 256)
    submarine.log_metric("score", 2)
    submarine.log_metric("score", 5)
    submarine.log_metric("score", 8)
    submarine.log_metric("score", 5)