def _UpdateDefaultEnvProtoByMultiClientEnvVars(env_proto):
    """Populate ``env_proto.ctrl_bootstrap_conf`` from environment variables.

    Reads ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE`` and ``RANK`` from
    the process environment, builds a ``BootstrapConf`` from them and copies
    it into ``env_proto.ctrl_bootstrap_conf``.

    Args:
        env_proto: message whose ``ctrl_bootstrap_conf`` field is overwritten.

    Raises:
        AssertionError: if ``HasAllMultiClientEnvVars()`` is falsy, or if one
            of the numeric variables is not made of digits only.
    """
    assert HasAllMultiClientEnvVars()

    def _env_as_int(var_name):
        # Only non-negative, digit-only values are accepted.
        raw_value = os.getenv(var_name)
        assert raw_value.isdigit()
        return int(raw_value)

    conf = ctrl_bootstrap_pb.BootstrapConf()

    # Build the master address separately, then copy it into the conf.
    address = ctrl_bootstrap_pb.Address()
    address.host = os.getenv("MASTER_ADDR")
    address.port = _env_as_int("MASTER_PORT")
    conf.master_addr.CopyFrom(address)

    conf.world_size = _env_as_int("WORLD_SIZE")
    conf.rank = _env_as_int("RANK")

    env_proto.ctrl_bootstrap_conf.CopyFrom(conf)
def _MakeBootstrapConf(bootstrap_info: dict):
    """Build a ``BootstrapConf`` from module-level settings and ``bootstrap_info``.

    The master address and world size come from the module-level
    ``config_master_addr`` and ``config_world_size``; the rank (required) and
    host (optional) come from ``bootstrap_info``. The control port is set only
    when the module-level ``config_bootstrap_ctrl_port`` is non-zero.

    Args:
        bootstrap_info: mapping that must contain ``"rank"`` and may contain
            ``"host"``.

    Returns:
        The populated ``ctrl_bootstrap_pb.BootstrapConf`` message.

    Raises:
        AssertionError: if the master host/port or world size have not been
            configured, or if ``"rank"`` is missing from ``bootstrap_info``.
    """
    # The module-level settings are only read here, so no ``global``
    # declaration is required.
    assert config_master_addr.HasField("host"), "must config master host first"
    assert config_master_addr.HasField("port"), "must config master port first"
    assert config_world_size != 0, "must config world size first"

    conf = ctrl_bootstrap_pb.BootstrapConf()
    conf.master_addr.CopyFrom(config_master_addr)
    conf.world_size = config_world_size

    assert "rank" in bootstrap_info
    conf.rank = bootstrap_info["rank"]

    if "host" in bootstrap_info:
        conf.host = bootstrap_info["host"]

    # A zero control port means "not configured": leave the field unset.
    if config_bootstrap_ctrl_port == 0:
        return conf

    conf.ctrl_port = config_bootstrap_ctrl_port
    return conf
def _UpdateDefaultEnvProtoByMultiClientEnvVars(env_proto):
    """Populate ``env_proto`` from the multi-client environment variables.

    Two sub-messages of ``env_proto`` are overwritten:

    * ``ctrl_bootstrap_conf`` is built from ``MASTER_ADDR``, ``MASTER_PORT``,
      ``WORLD_SIZE`` and ``RANK``.
    * ``cpp_logging_conf`` is built from the optional ``GLOG_log_dir``,
      ``GLOG_logtostderr`` and ``GLOG_logbuflevel`` variables; a variable that
      is unset or empty leaves the corresponding field untouched.

    Args:
        env_proto: message whose ``ctrl_bootstrap_conf`` and
            ``cpp_logging_conf`` fields are overwritten.

    Raises:
        AssertionError: if ``HasAllMultiClientEnvVars()`` is falsy, or if one
            of the bootstrap numeric variables is not made of digits only.
        ValueError: if ``GLOG_logtostderr`` or ``GLOG_logbuflevel`` is set but
            is not a valid integer literal.
    """
    assert HasAllMultiClientEnvVars()

    def str2int(env_config):
        # Only non-negative, digit-only values are accepted.
        assert env_config.isdigit()
        return int(env_config)

    bootstrap_conf = ctrl_bootstrap_pb.BootstrapConf()
    master_addr = ctrl_bootstrap_pb.Address()
    master_addr.host = os.getenv("MASTER_ADDR")
    master_addr.port = str2int(os.getenv("MASTER_PORT"))
    bootstrap_conf.master_addr.CopyFrom(master_addr)
    bootstrap_conf.world_size = str2int(os.getenv("WORLD_SIZE"))
    bootstrap_conf.rank = str2int(os.getenv("RANK"))
    env_proto.ctrl_bootstrap_conf.CopyFrom(bootstrap_conf)

    cpp_logging_conf = env_pb.CppLoggingConf()

    log_dir = os.getenv("GLOG_log_dir")
    if log_dir:
        cpp_logging_conf.log_dir = log_dir

    logtostderr = os.getenv("GLOG_logtostderr")
    if logtostderr:
        cpp_logging_conf.logtostderr = int(logtostderr)

    logbuflevel = os.getenv("GLOG_logbuflevel")
    if logbuflevel:
        # glog's logbuflevel is an integer flag; convert it like logtostderr
        # instead of assigning the raw environment string.
        # NOTE(review): assumes CppLoggingConf.logbuflevel is an integer
        # field - confirm against the proto definition.
        cpp_logging_conf.logbuflevel = int(logbuflevel)

    env_proto.cpp_logging_conf.CopyFrom(cpp_logging_conf)